mirror of
https://github.com/Relintai/rcpp_framework.git
synced 2024-11-14 04:57:21 +01:00
initial codestyle cleanup in the net folder.
This commit is contained in:
parent
8992cf49b3
commit
38a28ee9ce
File diff suppressed because it is too large
Load Diff
@ -30,54 +30,45 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "core/net/tcp_connection.h"
|
||||
#include "core/loops/timing_wheel.h"
|
||||
#include "core/net/tcp_connection.h"
|
||||
#include <list>
|
||||
#include <mutex>
|
||||
#ifndef _WIN32
|
||||
#include <unistd.h>
|
||||
#endif
|
||||
#include <thread>
|
||||
#include <array>
|
||||
|
||||
#include <thread>
|
||||
|
||||
#ifdef USE_OPENSSL
|
||||
enum class SSLStatus
|
||||
{
|
||||
Handshaking,
|
||||
Connecting,
|
||||
Connected,
|
||||
DisConnecting,
|
||||
DisConnected
|
||||
|
||||
enum class SSLStatus {
|
||||
Handshaking,
|
||||
Connecting,
|
||||
Connected,
|
||||
DisConnecting,
|
||||
DisConnected
|
||||
};
|
||||
|
||||
class SSLContext;
|
||||
class SSLConn;
|
||||
|
||||
std::shared_ptr<SSLContext> newSSLContext(
|
||||
bool useOldTLS,
|
||||
bool validateCert,
|
||||
const std::vector<std::pair<std::string, std::string>> &sslConfCmds);
|
||||
std::shared_ptr<SSLContext> newSSLServerContext(
|
||||
const std::string &certPath,
|
||||
const std::string &keyPath,
|
||||
bool useOldTLS,
|
||||
const std::vector<std::pair<std::string, std::string>> &sslConfCmds);
|
||||
// void initServerSSLContext(const std::shared_ptr<SSLContext> &ctx,
|
||||
// const std::string &certPath,
|
||||
// const std::string &keyPath);
|
||||
std::shared_ptr<SSLContext> newSSLContext(bool useOldTLS, bool validateCert, const std::vector<std::pair<std::string, std::string> > &sslConfCmds);
|
||||
std::shared_ptr<SSLContext> newSSLServerContext(const std::string &certPath, const std::string &keyPath, bool useOldTLS, const std::vector<std::pair<std::string, std::string> > &sslConfCmds);
|
||||
// void initServerSSLContext(const std::shared_ptr<SSLContext> &ctx, const std::string &certPath, const std::string &keyPath);
|
||||
|
||||
#endif
|
||||
|
||||
class Channel;
|
||||
class Socket;
|
||||
class TcpServer;
|
||||
void removeConnection(EventLoop *loop, const TcpConnectionPtr &conn);
|
||||
|
||||
class TcpConnectionImpl : public TcpConnection,
|
||||
public std::enable_shared_from_this<TcpConnectionImpl>
|
||||
{
|
||||
friend class TcpServer;
|
||||
friend class TcpClient;
|
||||
friend void removeConnection(EventLoop *loop,
|
||||
const TcpConnectionPtr &conn);
|
||||
class TcpConnectionImpl : public TcpConnection, public std::enable_shared_from_this<TcpConnectionImpl> {
|
||||
friend class TcpServer;
|
||||
friend class TcpClient;
|
||||
|
||||
friend void removeConnection(EventLoop *loop, const TcpConnectionPtr &conn);
|
||||
|
||||
protected:
|
||||
TcpConnectionImpl(const TcpConnectionImpl &) = delete;
|
||||
@ -86,278 +77,247 @@ protected:
|
||||
TcpConnectionImpl(TcpConnectionImpl &&) noexcept(true) = default;
|
||||
TcpConnectionImpl &operator=(TcpConnectionImpl &&) noexcept(true) = default;
|
||||
|
||||
public:
|
||||
class KickoffEntry
|
||||
{
|
||||
public:
|
||||
explicit KickoffEntry(const std::weak_ptr<TcpConnection> &conn)
|
||||
: conn_(conn)
|
||||
{
|
||||
}
|
||||
void reset()
|
||||
{
|
||||
conn_.reset();
|
||||
}
|
||||
~KickoffEntry()
|
||||
{
|
||||
auto conn = conn_.lock();
|
||||
if (conn)
|
||||
{
|
||||
conn->forceClose();
|
||||
}
|
||||
}
|
||||
public:
|
||||
class KickoffEntry {
|
||||
public:
|
||||
explicit KickoffEntry(const std::weak_ptr<TcpConnection> &conn) :
|
||||
conn_(conn) {
|
||||
}
|
||||
|
||||
private:
|
||||
std::weak_ptr<TcpConnection> conn_;
|
||||
};
|
||||
void reset() {
|
||||
conn_.reset();
|
||||
}
|
||||
|
||||
TcpConnectionImpl(EventLoop *loop,
|
||||
int socketfd,
|
||||
const InetAddress &localAddr,
|
||||
const InetAddress &peerAddr);
|
||||
~KickoffEntry() {
|
||||
auto conn = conn_.lock();
|
||||
if (conn) {
|
||||
conn->forceClose();
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
std::weak_ptr<TcpConnection> conn_;
|
||||
};
|
||||
|
||||
TcpConnectionImpl(EventLoop *loop, int socketfd, const InetAddress &localAddr, const InetAddress &peerAddr);
|
||||
#ifdef USE_OPENSSL
|
||||
TcpConnectionImpl(EventLoop *loop,
|
||||
int socketfd,
|
||||
const InetAddress &localAddr,
|
||||
const InetAddress &peerAddr,
|
||||
const std::shared_ptr<SSLContext> &ctxPtr,
|
||||
bool isServer = true,
|
||||
bool validateCert = true,
|
||||
const std::string &hostname = "");
|
||||
TcpConnectionImpl(EventLoop *loop, int socketfd, const InetAddress &localAddr, const InetAddress &peerAddr, const std::shared_ptr<SSLContext> &ctxPtr, bool isServer = true, bool validateCert = true, const std::string &hostname = "");
|
||||
#endif
|
||||
virtual ~TcpConnectionImpl();
|
||||
virtual void send(const char *msg, size_t len) override;
|
||||
virtual void send(const void *msg, size_t len) override;
|
||||
virtual void send(const std::string &msg) override;
|
||||
virtual void send(std::string &&msg) override;
|
||||
virtual void send(const MsgBuffer &buffer) override;
|
||||
virtual void send(MsgBuffer &&buffer) override;
|
||||
virtual void send(const std::shared_ptr<std::string> &msgPtr) override;
|
||||
virtual void send(const std::shared_ptr<MsgBuffer> &msgPtr) override;
|
||||
virtual void sendFile(const char *fileName,
|
||||
size_t offset = 0,
|
||||
size_t length = 0) override;
|
||||
|
||||
virtual const InetAddress &localAddr() const override
|
||||
{
|
||||
return localAddr_;
|
||||
}
|
||||
virtual const InetAddress &peerAddr() const override
|
||||
{
|
||||
return peerAddr_;
|
||||
}
|
||||
virtual ~TcpConnectionImpl();
|
||||
|
||||
virtual bool connected() const override
|
||||
{
|
||||
return status_ == ConnStatus::Connected;
|
||||
}
|
||||
virtual bool disconnected() const override
|
||||
{
|
||||
return status_ == ConnStatus::Disconnected;
|
||||
}
|
||||
virtual void send(const char *msg, size_t len) override;
|
||||
virtual void send(const void *msg, size_t len) override;
|
||||
virtual void send(const std::string &msg) override;
|
||||
virtual void send(std::string &&msg) override;
|
||||
virtual void send(const MsgBuffer &buffer) override;
|
||||
virtual void send(MsgBuffer &&buffer) override;
|
||||
virtual void send(const std::shared_ptr<std::string> &msgPtr) override;
|
||||
virtual void send(const std::shared_ptr<MsgBuffer> &msgPtr) override;
|
||||
virtual void sendFile(const char *fileName, size_t offset = 0, size_t length = 0) override;
|
||||
|
||||
// virtual MsgBuffer* getSendBuffer() override{ return &writeBuffer_;}
|
||||
virtual MsgBuffer *getRecvBuffer() override
|
||||
{
|
||||
return &readBuffer_;
|
||||
}
|
||||
// set callbacks
|
||||
virtual void setHighWaterMarkCallback(const HighWaterMarkCallback &cb,
|
||||
size_t markLen) override
|
||||
{
|
||||
highWaterMarkCallback_ = cb;
|
||||
highWaterMarkLen_ = markLen;
|
||||
}
|
||||
virtual const InetAddress &localAddr() const override {
|
||||
return localAddr_;
|
||||
}
|
||||
|
||||
virtual void keepAlive() override
|
||||
{
|
||||
idleTimeout_ = 0;
|
||||
auto entry = kickoffEntry_.lock();
|
||||
if (entry)
|
||||
{
|
||||
entry->reset();
|
||||
}
|
||||
}
|
||||
virtual bool isKeepAlive() override
|
||||
{
|
||||
return idleTimeout_ == 0;
|
||||
}
|
||||
virtual void setTcpNoDelay(bool on) override;
|
||||
virtual void shutdown() override;
|
||||
virtual void forceClose() override;
|
||||
virtual EventLoop *getLoop() override
|
||||
{
|
||||
return loop_;
|
||||
}
|
||||
virtual const InetAddress &peerAddr() const override {
|
||||
return peerAddr_;
|
||||
}
|
||||
|
||||
virtual size_t bytesSent() const override
|
||||
{
|
||||
return bytesSent_;
|
||||
}
|
||||
virtual size_t bytesReceived() const override
|
||||
{
|
||||
return bytesReceived_;
|
||||
}
|
||||
virtual void startClientEncryption(
|
||||
std::function<void()> callback,
|
||||
bool useOldTLS = false,
|
||||
bool validateCert = true,
|
||||
std::string hostname = "",
|
||||
const std::vector<std::pair<std::string, std::string>> &sslConfCmds =
|
||||
{}) override;
|
||||
virtual void startServerEncryption(const std::shared_ptr<SSLContext> &ctx,
|
||||
std::function<void()> callback) override;
|
||||
virtual bool isSSLConnection() const override
|
||||
{
|
||||
return isEncrypted_;
|
||||
}
|
||||
virtual bool connected() const override {
|
||||
return status_ == ConnStatus::Connected;
|
||||
}
|
||||
|
||||
private:
|
||||
/// Internal use only.
|
||||
virtual bool disconnected() const override {
|
||||
return status_ == ConnStatus::Disconnected;
|
||||
}
|
||||
|
||||
std::weak_ptr<KickoffEntry> kickoffEntry_;
|
||||
std::weak_ptr<TimingWheel> timingWheelWeakPtr_;
|
||||
size_t idleTimeout_{0};
|
||||
Date lastTimingWheelUpdateTime_;
|
||||
// virtual MsgBuffer* getSendBuffer() override{ return &writeBuffer_;}
|
||||
virtual MsgBuffer *getRecvBuffer() override {
|
||||
return &readBuffer_;
|
||||
}
|
||||
|
||||
void enableKickingOff(size_t timeout,
|
||||
const std::shared_ptr<TimingWheel> &timingWheel)
|
||||
{
|
||||
assert(timingWheel);
|
||||
assert(timingWheel->getLoop() == loop_);
|
||||
assert(timeout > 0);
|
||||
auto entry = std::make_shared<KickoffEntry>(shared_from_this());
|
||||
kickoffEntry_ = entry;
|
||||
timingWheelWeakPtr_ = timingWheel;
|
||||
idleTimeout_ = timeout;
|
||||
timingWheel->insertEntry(timeout, entry);
|
||||
}
|
||||
void extendLife();
|
||||
// set callbacks
|
||||
virtual void setHighWaterMarkCallback(const HighWaterMarkCallback &cb,
|
||||
size_t markLen) override {
|
||||
highWaterMarkCallback_ = cb;
|
||||
highWaterMarkLen_ = markLen;
|
||||
}
|
||||
|
||||
virtual void keepAlive() override {
|
||||
idleTimeout_ = 0;
|
||||
auto entry = kickoffEntry_.lock();
|
||||
if (entry) {
|
||||
entry->reset();
|
||||
}
|
||||
}
|
||||
|
||||
virtual bool isKeepAlive() override {
|
||||
return idleTimeout_ == 0;
|
||||
}
|
||||
|
||||
virtual void setTcpNoDelay(bool on) override;
|
||||
virtual void shutdown() override;
|
||||
virtual void forceClose() override;
|
||||
|
||||
virtual EventLoop *getLoop() override {
|
||||
return loop_;
|
||||
}
|
||||
|
||||
virtual size_t bytesSent() const override {
|
||||
return bytesSent_;
|
||||
}
|
||||
|
||||
virtual size_t bytesReceived() const override {
|
||||
return bytesReceived_;
|
||||
}
|
||||
|
||||
virtual void startClientEncryption(std::function<void()> callback, bool useOldTLS = false, bool validateCert = true, std::string hostname = "", const std::vector<std::pair<std::string, std::string> > &sslConfCmds = {}) override;
|
||||
virtual void startServerEncryption(const std::shared_ptr<SSLContext> &ctx, std::function<void()> callback) override;
|
||||
virtual bool isSSLConnection() const override {
|
||||
return isEncrypted_;
|
||||
}
|
||||
|
||||
private:
|
||||
/// Internal use only.
|
||||
|
||||
std::weak_ptr<KickoffEntry> kickoffEntry_;
|
||||
std::weak_ptr<TimingWheel> timingWheelWeakPtr_;
|
||||
size_t idleTimeout_{ 0 };
|
||||
Date lastTimingWheelUpdateTime_;
|
||||
|
||||
void enableKickingOff(size_t timeout,
|
||||
const std::shared_ptr<TimingWheel> &timingWheel) {
|
||||
assert(timingWheel);
|
||||
assert(timingWheel->getLoop() == loop_);
|
||||
assert(timeout > 0);
|
||||
auto entry = std::make_shared<KickoffEntry>(shared_from_this());
|
||||
kickoffEntry_ = entry;
|
||||
timingWheelWeakPtr_ = timingWheel;
|
||||
idleTimeout_ = timeout;
|
||||
timingWheel->insertEntry(timeout, entry);
|
||||
}
|
||||
|
||||
void extendLife();
|
||||
#ifndef _WIN32
|
||||
void sendFile(int sfd, size_t offset = 0, size_t length = 0);
|
||||
void sendFile(int sfd, size_t offset = 0, size_t length = 0);
|
||||
#else
|
||||
void sendFile(FILE *fp, size_t offset = 0, size_t length = 0);
|
||||
void sendFile(FILE *fp, size_t offset = 0, size_t length = 0);
|
||||
#endif
|
||||
void setRecvMsgCallback(const RecvMessageCallback &cb)
|
||||
{
|
||||
recvMsgCallback_ = cb;
|
||||
}
|
||||
void setConnectionCallback(const ConnectionCallback &cb)
|
||||
{
|
||||
connectionCallback_ = cb;
|
||||
}
|
||||
void setWriteCompleteCallback(const WriteCompleteCallback &cb)
|
||||
{
|
||||
writeCompleteCallback_ = cb;
|
||||
}
|
||||
void setCloseCallback(const CloseCallback &cb)
|
||||
{
|
||||
closeCallback_ = cb;
|
||||
}
|
||||
void setSSLErrorCallback(const SSLErrorCallback &cb)
|
||||
{
|
||||
sslErrorCallback_ = cb;
|
||||
}
|
||||
void setRecvMsgCallback(const RecvMessageCallback &cb) {
|
||||
recvMsgCallback_ = cb;
|
||||
}
|
||||
void setConnectionCallback(const ConnectionCallback &cb) {
|
||||
connectionCallback_ = cb;
|
||||
}
|
||||
void setWriteCompleteCallback(const WriteCompleteCallback &cb) {
|
||||
writeCompleteCallback_ = cb;
|
||||
}
|
||||
void setCloseCallback(const CloseCallback &cb) {
|
||||
closeCallback_ = cb;
|
||||
}
|
||||
void setSSLErrorCallback(const SSLErrorCallback &cb) {
|
||||
sslErrorCallback_ = cb;
|
||||
}
|
||||
|
||||
void connectDestroyed();
|
||||
virtual void connectEstablished();
|
||||
void connectDestroyed();
|
||||
virtual void connectEstablished();
|
||||
|
||||
protected:
|
||||
struct BufferNode
|
||||
{
|
||||
protected:
|
||||
struct BufferNode {
|
||||
#ifndef _WIN32
|
||||
int sendFd_{-1};
|
||||
off_t offset_;
|
||||
int sendFd_{ -1 };
|
||||
off_t offset_;
|
||||
#else
|
||||
FILE *sendFp_{nullptr};
|
||||
long long offset_;
|
||||
FILE *sendFp_{ nullptr };
|
||||
long long offset_;
|
||||
#endif
|
||||
ssize_t fileBytesToSend_;
|
||||
std::shared_ptr<MsgBuffer> msgBuffer_;
|
||||
~BufferNode()
|
||||
{
|
||||
ssize_t fileBytesToSend_;
|
||||
std::shared_ptr<MsgBuffer> msgBuffer_;
|
||||
|
||||
~BufferNode() {
|
||||
#ifndef _WIN32
|
||||
if (sendFd_ >= 0)
|
||||
close(sendFd_);
|
||||
if (sendFd_ >= 0)
|
||||
close(sendFd_);
|
||||
#else
|
||||
if (sendFp_)
|
||||
fclose(sendFp_);
|
||||
if (sendFp_)
|
||||
fclose(sendFp_);
|
||||
#endif
|
||||
}
|
||||
};
|
||||
using BufferNodePtr = std::shared_ptr<BufferNode>;
|
||||
enum class ConnStatus
|
||||
{
|
||||
Disconnected,
|
||||
Connecting,
|
||||
Connected,
|
||||
Disconnecting
|
||||
};
|
||||
bool isEncrypted_{false};
|
||||
EventLoop *loop_;
|
||||
std::unique_ptr<Channel> ioChannelPtr_;
|
||||
std::unique_ptr<Socket> socketPtr_;
|
||||
MsgBuffer readBuffer_;
|
||||
std::list<BufferNodePtr> writeBufferList_;
|
||||
void readCallback();
|
||||
void writeCallback();
|
||||
InetAddress localAddr_, peerAddr_;
|
||||
ConnStatus status_{ConnStatus::Connecting};
|
||||
// callbacks
|
||||
RecvMessageCallback recvMsgCallback_;
|
||||
ConnectionCallback connectionCallback_;
|
||||
CloseCallback closeCallback_;
|
||||
WriteCompleteCallback writeCompleteCallback_;
|
||||
HighWaterMarkCallback highWaterMarkCallback_;
|
||||
SSLErrorCallback sslErrorCallback_;
|
||||
void handleClose();
|
||||
void handleError();
|
||||
// virtual void sendInLoop(const std::string &msg);
|
||||
}
|
||||
};
|
||||
|
||||
void sendFileInLoop(const BufferNodePtr &file);
|
||||
using BufferNodePtr = std::shared_ptr<BufferNode>;
|
||||
enum class ConnStatus {
|
||||
Disconnected,
|
||||
Connecting,
|
||||
Connected,
|
||||
Disconnecting
|
||||
};
|
||||
|
||||
bool isEncrypted_{ false };
|
||||
EventLoop *loop_;
|
||||
std::unique_ptr<Channel> ioChannelPtr_;
|
||||
std::unique_ptr<Socket> socketPtr_;
|
||||
MsgBuffer readBuffer_;
|
||||
std::list<BufferNodePtr> writeBufferList_;
|
||||
|
||||
void readCallback();
|
||||
void writeCallback();
|
||||
|
||||
InetAddress localAddr_, peerAddr_;
|
||||
ConnStatus status_{ ConnStatus::Connecting };
|
||||
// callbacks
|
||||
RecvMessageCallback recvMsgCallback_;
|
||||
ConnectionCallback connectionCallback_;
|
||||
CloseCallback closeCallback_;
|
||||
WriteCompleteCallback writeCompleteCallback_;
|
||||
HighWaterMarkCallback highWaterMarkCallback_;
|
||||
SSLErrorCallback sslErrorCallback_;
|
||||
|
||||
void handleClose();
|
||||
void handleError();
|
||||
// virtual void sendInLoop(const std::string &msg);
|
||||
|
||||
void sendFileInLoop(const BufferNodePtr &file);
|
||||
#ifndef _WIN32
|
||||
void sendInLoop(const void *buffer, size_t length);
|
||||
ssize_t writeInLoop(const void *buffer, size_t length);
|
||||
void sendInLoop(const void *buffer, size_t length);
|
||||
ssize_t writeInLoop(const void *buffer, size_t length);
|
||||
#else
|
||||
void sendInLoop(const char *buffer, size_t length);
|
||||
ssize_t writeInLoop(const char *buffer, size_t length);
|
||||
void sendInLoop(const char *buffer, size_t length);
|
||||
ssize_t writeInLoop(const char *buffer, size_t length);
|
||||
#endif
|
||||
size_t highWaterMarkLen_;
|
||||
std::string name_;
|
||||
size_t highWaterMarkLen_;
|
||||
std::string name_;
|
||||
|
||||
uint64_t sendNum_{0};
|
||||
std::mutex sendNumMutex_;
|
||||
uint64_t sendNum_{ 0 };
|
||||
std::mutex sendNumMutex_;
|
||||
|
||||
size_t bytesSent_{0};
|
||||
size_t bytesReceived_{0};
|
||||
size_t bytesSent_{ 0 };
|
||||
size_t bytesReceived_{ 0 };
|
||||
|
||||
std::unique_ptr<std::vector<char>> fileBufferPtr_;
|
||||
std::unique_ptr<std::vector<char> > fileBufferPtr_;
|
||||
|
||||
#ifdef USE_OPENSSL
|
||||
private:
|
||||
void doHandshaking();
|
||||
bool validatePeerCertificate();
|
||||
struct SSLEncryption
|
||||
{
|
||||
SSLStatus statusOfSSL_ = SSLStatus::Handshaking;
|
||||
// OpenSSL
|
||||
std::shared_ptr<SSLContext> sslCtxPtr_;
|
||||
std::unique_ptr<SSLConn> sslPtr_;
|
||||
std::unique_ptr<std::array<char, 8192>> sendBufferPtr_;
|
||||
bool isServer_{false};
|
||||
bool isUpgrade_{false};
|
||||
std::function<void()> upgradeCallback_;
|
||||
std::string hostname_;
|
||||
};
|
||||
std::unique_ptr<SSLEncryption> sslEncryptionPtr_;
|
||||
void startClientEncryptionInLoop(
|
||||
std::function<void()> &&callback,
|
||||
bool useOldTLS,
|
||||
bool validateCert,
|
||||
const std::string &hostname,
|
||||
const std::vector<std::pair<std::string, std::string>> &sslConfCmds);
|
||||
void startServerEncryptionInLoop(const std::shared_ptr<SSLContext> &ctx,
|
||||
std::function<void()> &&callback);
|
||||
private:
|
||||
void doHandshaking();
|
||||
bool validatePeerCertificate();
|
||||
|
||||
struct SSLEncryption {
|
||||
SSLStatus statusOfSSL_ = SSLStatus::Handshaking;
|
||||
// OpenSSL
|
||||
std::shared_ptr<SSLContext> sslCtxPtr_;
|
||||
std::unique_ptr<SSLConn> sslPtr_;
|
||||
std::unique_ptr<std::array<char, 8192> > sendBufferPtr_;
|
||||
bool isServer_{ false };
|
||||
bool isUpgrade_{ false };
|
||||
std::function<void()> upgradeCallback_;
|
||||
std::string hostname_;
|
||||
};
|
||||
|
||||
std::unique_ptr<SSLEncryption> sslEncryptionPtr_;
|
||||
|
||||
void startClientEncryptionInLoop(std::function<void()> &&callback, bool useOldTLS, bool validateCert, const std::string &hostname, const std::vector<std::pair<std::string, std::string> > &sslConfCmds);
|
||||
void startServerEncryptionInLoop(const std::shared_ptr<SSLContext> &ctx, std::function<void()> &&callback);
|
||||
#endif
|
||||
};
|
||||
|
||||
|
@ -35,6 +35,7 @@
|
||||
Connector::Connector(EventLoop *loop, const InetAddress &addr, bool retry) :
|
||||
loop_(loop), serverAddr_(addr), retry_(retry) {
|
||||
}
|
||||
|
||||
Connector::Connector(EventLoop *loop, InetAddress &&addr, bool retry) :
|
||||
loop_(loop), serverAddr_(std::move(addr)), retry_(retry) {
|
||||
}
|
||||
@ -43,8 +44,10 @@ void Connector::start() {
|
||||
connect_ = true;
|
||||
loop_->runInLoop([this]() { startInLoop(); });
|
||||
}
|
||||
|
||||
void Connector::restart() {
|
||||
}
|
||||
|
||||
void Connector::stop() {
|
||||
}
|
||||
|
||||
@ -57,6 +60,7 @@ void Connector::startInLoop() {
|
||||
LOG_DEBUG << "do not connect";
|
||||
}
|
||||
}
|
||||
|
||||
void Connector::connect() {
|
||||
int sockfd = Socket::createNonblockingSocketOrDie(serverAddr_.family());
|
||||
errno = 0;
|
||||
|
@ -46,23 +46,30 @@ protected:
|
||||
public:
|
||||
using NewConnectionCallback = std::function<void(int sockfd)>;
|
||||
using ConnectionErrorCallback = std::function<void()>;
|
||||
|
||||
Connector(EventLoop *loop, const InetAddress &addr, bool retry = true);
|
||||
Connector(EventLoop *loop, InetAddress &&addr, bool retry = true);
|
||||
|
||||
void setNewConnectionCallback(const NewConnectionCallback &cb) {
|
||||
newConnectionCallback_ = cb;
|
||||
}
|
||||
|
||||
void setNewConnectionCallback(NewConnectionCallback &&cb) {
|
||||
newConnectionCallback_ = std::move(cb);
|
||||
}
|
||||
|
||||
void setErrorCallback(const ConnectionErrorCallback &cb) {
|
||||
errorCallback_ = cb;
|
||||
}
|
||||
|
||||
void setErrorCallback(ConnectionErrorCallback &&cb) {
|
||||
errorCallback_ = std::move(cb);
|
||||
}
|
||||
|
||||
const InetAddress &serverAddress() const {
|
||||
return serverAddr_;
|
||||
}
|
||||
|
||||
void start();
|
||||
void restart();
|
||||
void stop();
|
||||
@ -70,11 +77,13 @@ public:
|
||||
private:
|
||||
NewConnectionCallback newConnectionCallback_;
|
||||
ConnectionErrorCallback errorCallback_;
|
||||
|
||||
enum class Status {
|
||||
Disconnected,
|
||||
Connecting,
|
||||
Connected
|
||||
};
|
||||
|
||||
static constexpr int kMaxRetryDelayMs = 30 * 1000;
|
||||
static constexpr int kInitRetryDelayMs = 500;
|
||||
std::shared_ptr<Channel> channelPtr_;
|
||||
|
@ -98,8 +98,7 @@ InetAddress::InetAddress(uint16_t port, bool loopbackOnly, bool ipv6) :
|
||||
isUnspecified_ = false;
|
||||
}
|
||||
|
||||
InetAddress::InetAddress(const std::string &ip, uint16_t port, bool ipv6) :
|
||||
isIpV6_(ipv6) {
|
||||
InetAddress::InetAddress(const std::string &ip, uint16_t port, bool ipv6) : isIpV6_(ipv6) {
|
||||
if (ipv6) {
|
||||
memset(&addr6_, 0, sizeof(addr6_));
|
||||
addr6_.sin6_family = AF_INET6;
|
||||
|
@ -38,195 +38,76 @@ using sa_family_t = unsigned short;
|
||||
using in_addr_t = uint32_t;
|
||||
using uint16_t = unsigned short;
|
||||
#else
|
||||
#include <netinet/in.h>
|
||||
#include <arpa/inet.h>
|
||||
#include <netinet/in.h>
|
||||
#include <sys/socket.h>
|
||||
#endif
|
||||
#include <mutex>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <mutex>
|
||||
|
||||
/**
|
||||
* @brief Wrapper of sockaddr_in. This is an POD interface class.
|
||||
*
|
||||
*/
|
||||
class InetAddress
|
||||
{
|
||||
public:
|
||||
/**
|
||||
* @brief Constructs an endpoint with given port number. Mostly used in
|
||||
* TcpServer listening.
|
||||
*
|
||||
* @param port
|
||||
* @param loopbackOnly
|
||||
* @param ipv6
|
||||
*/
|
||||
InetAddress(uint16_t port = 0,
|
||||
bool loopbackOnly = false,
|
||||
bool ipv6 = false);
|
||||
class InetAddress {
|
||||
public:
|
||||
InetAddress(uint16_t port = 0, bool loopbackOnly = false, bool ipv6 = false);
|
||||
|
||||
/**
|
||||
* @brief Constructs an endpoint with given ip and port.
|
||||
*
|
||||
* @param ip A IPv4 or IPv6 address.
|
||||
* @param port
|
||||
* @param ipv6
|
||||
*/
|
||||
InetAddress(const std::string &ip, uint16_t port, bool ipv6 = false);
|
||||
InetAddress(const std::string &ip, uint16_t port, bool ipv6 = false);
|
||||
|
||||
/**
|
||||
* @brief Constructs an endpoint with given struct `sockaddr_in`. Mostly
|
||||
* used when accepting new connections
|
||||
*
|
||||
* @param addr
|
||||
*/
|
||||
explicit InetAddress(const struct sockaddr_in &addr)
|
||||
: addr_(addr), isUnspecified_(false)
|
||||
{
|
||||
}
|
||||
explicit InetAddress(const struct sockaddr_in &addr) :
|
||||
addr_(addr), isUnspecified_(false) {
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Constructs an IPv6 endpoint with given struct `sockaddr_in6`.
|
||||
* Mostly used when accepting new connections
|
||||
*
|
||||
* @param addr
|
||||
*/
|
||||
explicit InetAddress(const struct sockaddr_in6 &addr)
|
||||
: addr6_(addr), isIpV6_(true), isUnspecified_(false)
|
||||
{
|
||||
}
|
||||
explicit InetAddress(const struct sockaddr_in6 &addr) :
|
||||
addr6_(addr), isIpV6_(true), isUnspecified_(false) {
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Return the sin_family of the endpoint.
|
||||
*
|
||||
* @return sa_family_t
|
||||
*/
|
||||
sa_family_t family() const
|
||||
{
|
||||
return addr_.sin_family;
|
||||
}
|
||||
sa_family_t family() const {
|
||||
return addr_.sin_family;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Return the IP string of the endpoint.
|
||||
*
|
||||
* @return std::string
|
||||
*/
|
||||
std::string toIp() const;
|
||||
std::string toIp() const;
|
||||
std::string toIpPort() const;
|
||||
uint16_t toPort() const;
|
||||
|
||||
/**
|
||||
* @brief Return the IP and port string of the endpoint.
|
||||
*
|
||||
* @return std::string
|
||||
*/
|
||||
std::string toIpPort() const;
|
||||
bool isIpV6() const {
|
||||
return isIpV6_;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Return the port number of the endpoint.
|
||||
*
|
||||
* @return uint16_t
|
||||
*/
|
||||
uint16_t toPort() const;
|
||||
bool isIntranetIp() const;
|
||||
bool isLoopbackIp() const;
|
||||
|
||||
/**
|
||||
* @brief Check if the endpoint is IPv4 or IPv6.
|
||||
*
|
||||
* @return true
|
||||
* @return false
|
||||
*/
|
||||
bool isIpV6() const
|
||||
{
|
||||
return isIpV6_;
|
||||
}
|
||||
const struct sockaddr *getSockAddr() const {
|
||||
return static_cast<const struct sockaddr *>((void *)(&addr6_));
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Return true if the endpoint is an intranet endpoint.
|
||||
*
|
||||
* @return true
|
||||
* @return false
|
||||
*/
|
||||
bool isIntranetIp() const;
|
||||
void setSockAddrInet6(const struct sockaddr_in6 &addr6) {
|
||||
addr6_ = addr6;
|
||||
isIpV6_ = (addr6_.sin6_family == AF_INET6);
|
||||
isUnspecified_ = false;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Return true if the endpoint is a loopback endpoint.
|
||||
*
|
||||
* @return true
|
||||
* @return false
|
||||
*/
|
||||
bool isLoopbackIp() const;
|
||||
uint32_t ipNetEndian() const;
|
||||
const uint32_t *ip6NetEndian() const;
|
||||
|
||||
/**
|
||||
* @brief Get the pointer to the sockaddr struct.
|
||||
*
|
||||
* @return const struct sockaddr*
|
||||
*/
|
||||
const struct sockaddr *getSockAddr() const
|
||||
{
|
||||
return static_cast<const struct sockaddr *>((void *)(&addr6_));
|
||||
}
|
||||
uint16_t portNetEndian() const {
|
||||
return addr_.sin_port;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Set the sockaddr_in6 struct in the endpoint.
|
||||
*
|
||||
* @param addr6
|
||||
*/
|
||||
void setSockAddrInet6(const struct sockaddr_in6 &addr6)
|
||||
{
|
||||
addr6_ = addr6;
|
||||
isIpV6_ = (addr6_.sin6_family == AF_INET6);
|
||||
isUnspecified_ = false;
|
||||
}
|
||||
void setPortNetEndian(uint16_t port) {
|
||||
addr_.sin_port = port;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Return the integer value of the IP(v4) in net endian byte order.
|
||||
*
|
||||
* @return uint32_t
|
||||
*/
|
||||
uint32_t ipNetEndian() const;
|
||||
inline bool isUnspecified() const {
|
||||
return isUnspecified_;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Return the pointer to the integer value of the IP(v6) in net
|
||||
* endian byte order.
|
||||
*
|
||||
* @return const uint32_t*
|
||||
*/
|
||||
const uint32_t *ip6NetEndian() const;
|
||||
private:
|
||||
union {
|
||||
struct sockaddr_in addr_;
|
||||
struct sockaddr_in6 addr6_;
|
||||
};
|
||||
|
||||
/**
|
||||
* @brief Return the port number in net endian byte order.
|
||||
*
|
||||
* @return uint16_t
|
||||
*/
|
||||
uint16_t portNetEndian() const
|
||||
{
|
||||
return addr_.sin_port;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Set the port number in net endian byte order.
|
||||
*
|
||||
* @param port
|
||||
*/
|
||||
void setPortNetEndian(uint16_t port)
|
||||
{
|
||||
addr_.sin_port = port;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Return true if the address is not initalized.
|
||||
*/
|
||||
inline bool isUnspecified() const
|
||||
{
|
||||
return isUnspecified_;
|
||||
}
|
||||
|
||||
private:
|
||||
union
|
||||
{
|
||||
struct sockaddr_in addr_;
|
||||
struct sockaddr_in6 addr6_;
|
||||
};
|
||||
bool isIpV6_{false};
|
||||
bool isUnspecified_{true};
|
||||
bool isIpV6_{ false };
|
||||
bool isUnspecified_{ true };
|
||||
};
|
||||
|
||||
#endif // MUDUO_NET_INETADDRESS_H
|
||||
#endif // MUDUO_NET_INETADDRESS_H
|
||||
|
@ -34,44 +34,20 @@
|
||||
#include "core/loops/event_loop.h"
|
||||
#include "core/net/inet_address.h"
|
||||
|
||||
/**
|
||||
* @brief This class represents an asynchronous DNS resolver.
|
||||
* @note Although the c-ares library is not essential, it is recommended to
|
||||
* install it for higher performance
|
||||
*/
|
||||
//make it a reference
|
||||
|
||||
class Resolver
|
||||
{
|
||||
public:
|
||||
using Callback = std::function<void(const InetAddress&)>;
|
||||
|
||||
/**
|
||||
* @brief Create a new DNS resolver.
|
||||
*
|
||||
* @param loop The event loop in which the DNS resolver runs.
|
||||
* @param timeout The timeout in seconds for DNS.
|
||||
* @return std::shared_ptr<Resolver>
|
||||
*/
|
||||
static std::shared_ptr<Resolver> newResolver(EventLoop* loop = nullptr,
|
||||
size_t timeout = 60);
|
||||
static std::shared_ptr<Resolver> newResolver(EventLoop* loop = nullptr, size_t timeout = 60);
|
||||
|
||||
/**
|
||||
* @brief Resolve an address asynchronously.
|
||||
*
|
||||
* @param hostname
|
||||
* @param callback
|
||||
*/
|
||||
virtual void resolve(const std::string& hostname,
|
||||
const Callback& callback) = 0;
|
||||
virtual void resolve(const std::string& hostname, const Callback& callback) = 0;
|
||||
|
||||
virtual ~Resolver()
|
||||
{
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Check whether the c-ares library is used.
|
||||
*
|
||||
* @return true
|
||||
* @return false
|
||||
*/
|
||||
static bool isCAresUsed();
|
||||
};
|
||||
|
@ -29,8 +29,8 @@
|
||||
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
#include "ares_resolver.h"
|
||||
#include <ares.h>
|
||||
#include "core/loops/channel.h"
|
||||
#include <ares.h>
|
||||
#ifdef _WIN32
|
||||
#include <winsock2.h>
|
||||
#else
|
||||
@ -67,14 +67,14 @@ bool Resolver::isCAresUsed() {
|
||||
AresResolver::LibraryInitializer::LibraryInitializer() {
|
||||
ares_library_init(ARES_LIB_INIT_ALL);
|
||||
}
|
||||
|
||||
AresResolver::LibraryInitializer::~LibraryInitializer() {
|
||||
ares_library_cleanup();
|
||||
}
|
||||
|
||||
AresResolver::LibraryInitializer AresResolver::libraryInitializer_;
|
||||
|
||||
std::shared_ptr<Resolver> Resolver::newResolver(EventLoop *loop,
|
||||
size_t timeout) {
|
||||
std::shared_ptr<Resolver> Resolver::newResolver(EventLoop *loop, size_t timeout) {
|
||||
return std::make_shared<AresResolver>(loop, timeout);
|
||||
}
|
||||
|
||||
@ -84,6 +84,7 @@ AresResolver::AresResolver(EventLoop *loop, size_t timeout) :
|
||||
loop_ = getLoop();
|
||||
}
|
||||
}
|
||||
|
||||
void AresResolver::init() {
|
||||
if (!ctx_) {
|
||||
struct ares_options options;
|
||||
@ -108,14 +109,16 @@ void AresResolver::init() {
|
||||
this);
|
||||
}
|
||||
}
|
||||
|
||||
AresResolver::~AresResolver() {
|
||||
if (ctx_)
|
||||
ares_destroy(ctx_);
|
||||
}
|
||||
|
||||
void AresResolver::resolveInLoop(const std::string &hostname,
|
||||
const Callback &cb) {
|
||||
void AresResolver::resolveInLoop(const std::string &hostname, const Callback &cb) {
|
||||
|
||||
loop_->assertInLoopThread();
|
||||
|
||||
#ifdef _WIN32
|
||||
if (hostname == "localhost") {
|
||||
const static InetAddress localhost_{ "127.0.0.1", 0 };
|
||||
@ -123,20 +126,17 @@ void AresResolver::resolveInLoop(const std::string &hostname,
|
||||
return;
|
||||
}
|
||||
#endif
|
||||
|
||||
init();
|
||||
QueryData *queryData = new QueryData(this, cb, hostname);
|
||||
ares_gethostbyname(ctx_,
|
||||
hostname.c_str(),
|
||||
AF_INET,
|
||||
&AresResolver::ares_hostcallback_,
|
||||
queryData);
|
||||
ares_gethostbyname(ctx_, hostname.c_str(), AF_INET, &AresResolver::ares_hostcallback_, queryData);
|
||||
struct timeval tv;
|
||||
struct timeval *tvp = ares_timeout(ctx_, NULL, &tv);
|
||||
double timeout = getSeconds(tvp);
|
||||
|
||||
// LOG_DEBUG << "timeout " << timeout << " active " << timerActive_;
|
||||
if (!timerActive_ && timeout >= 0.0) {
|
||||
loop_->runAfter(timeout,
|
||||
std::bind(&AresResolver::onTimer, shared_from_this()));
|
||||
loop_->runAfter(timeout, std::bind(&AresResolver::onTimer, shared_from_this()));
|
||||
timerActive_ = true;
|
||||
}
|
||||
return;
|
||||
@ -161,18 +161,18 @@ void AresResolver::onTimer() {
|
||||
}
|
||||
}
|
||||
|
||||
void AresResolver::onQueryResult(int status,
|
||||
struct hostent *result,
|
||||
const std::string &hostname,
|
||||
const Callback &callback) {
|
||||
void AresResolver::onQueryResult(int status, struct hostent *result, const std::string &hostname, const Callback &callback) {
|
||||
|
||||
LOG_TRACE << "onQueryResult " << status;
|
||||
struct sockaddr_in addr;
|
||||
memset(&addr, 0, sizeof addr);
|
||||
addr.sin_family = AF_INET;
|
||||
addr.sin_port = 0;
|
||||
|
||||
if (result) {
|
||||
addr.sin_addr = *reinterpret_cast<in_addr *>(result->h_addr);
|
||||
}
|
||||
|
||||
InetAddress inet(addr);
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(globalMutex());
|
||||
@ -180,6 +180,7 @@ void AresResolver::onQueryResult(int status,
|
||||
addrItem.first = addr.sin_addr;
|
||||
addrItem.second = Date::date();
|
||||
}
|
||||
|
||||
callback(inet);
|
||||
}
|
||||
|
||||
@ -198,6 +199,7 @@ void AresResolver::onSockStateChange(int sockfd, bool read, bool write) {
|
||||
loop_->assertInLoopThread();
|
||||
ChannelList::iterator it = channels_.find(sockfd);
|
||||
assert(it != channels_.end());
|
||||
|
||||
if (read) {
|
||||
// update
|
||||
// if (write) { } else { }
|
||||
@ -209,17 +211,12 @@ void AresResolver::onSockStateChange(int sockfd, bool read, bool write) {
|
||||
}
|
||||
}
|
||||
|
||||
void AresResolver::ares_hostcallback_(void *data,
|
||||
int status,
|
||||
int timeouts,
|
||||
struct hostent *hostent) {
|
||||
void AresResolver::ares_hostcallback_(void *data, int status, int timeouts, struct hostent *hostent) {
|
||||
(void)timeouts;
|
||||
QueryData *query = static_cast<QueryData *>(data);
|
||||
|
||||
query->owner_->onQueryResult(status,
|
||||
hostent,
|
||||
query->hostname_,
|
||||
query->callback_);
|
||||
query->owner_->onQueryResult(status, hostent, query->hostname_, query->callback_);
|
||||
|
||||
delete query;
|
||||
}
|
||||
|
||||
@ -242,6 +239,7 @@ void AresResolver::ares_sock_statecallback_(void *data,
|
||||
#endif
|
||||
int read,
|
||||
int write) {
|
||||
|
||||
LOG_TRACE << "sockfd=" << sockfd << " read=" << read << " write=" << write;
|
||||
static_cast<AresResolver *>(data)->onSockStateChange(sockfd, read, write);
|
||||
}
|
||||
|
@ -36,14 +36,15 @@
|
||||
#include <map>
|
||||
#include <memory>
|
||||
|
||||
// Resolver will be a ref
|
||||
|
||||
extern "C" {
|
||||
struct hostent;
|
||||
struct ares_channeldata;
|
||||
using ares_channel = struct ares_channeldata *;
|
||||
}
|
||||
|
||||
class AresResolver : public Resolver,
|
||||
public std::enable_shared_from_this<AresResolver> {
|
||||
class AresResolver : public Resolver, public std::enable_shared_from_this<AresResolver> {
|
||||
protected:
|
||||
AresResolver(const AresResolver &) = delete;
|
||||
AresResolver &operator=(const AresResolver &) = delete;
|
||||
@ -55,9 +56,9 @@ public:
|
||||
AresResolver(EventLoop *loop, size_t timeout);
|
||||
~AresResolver();
|
||||
|
||||
virtual void resolve(const std::string &hostname,
|
||||
const Callback &cb) override {
|
||||
virtual void resolve(const std::string &hostname, const Callback &cb) override {
|
||||
bool cached = false;
|
||||
|
||||
InetAddress inet;
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(globalMutex());
|
||||
@ -76,10 +77,12 @@ public:
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (cached) {
|
||||
cb(inet);
|
||||
return;
|
||||
}
|
||||
|
||||
if (loop_->isInLoopThread()) {
|
||||
resolveInLoop(hostname, cb);
|
||||
} else {
|
||||
@ -94,14 +97,15 @@ private:
|
||||
AresResolver *owner_;
|
||||
Callback callback_;
|
||||
std::string hostname_;
|
||||
QueryData(AresResolver *o,
|
||||
const Callback &cb,
|
||||
const std::string &hostname) :
|
||||
|
||||
QueryData(AresResolver *o, const Callback &cb, const std::string &hostname) :
|
||||
owner_(o), callback_(cb), hostname_(hostname) {
|
||||
}
|
||||
};
|
||||
|
||||
void resolveInLoop(const std::string &hostname, const Callback &cb);
|
||||
void init();
|
||||
|
||||
EventLoop *loop_;
|
||||
ares_channel ctx_{ nullptr };
|
||||
bool timerActive_{ false };
|
||||
@ -109,36 +113,33 @@ private:
|
||||
ChannelList channels_;
|
||||
static std::unordered_map<std::string,
|
||||
std::pair<struct in_addr, Date> > &
|
||||
|
||||
globalCache() {
|
||||
static std::unordered_map<std::string,
|
||||
std::pair<struct in_addr, Date> >
|
||||
dnsCache;
|
||||
static std::unordered_map<std::string, std::pair<struct in_addr, Date> > dnsCache;
|
||||
return dnsCache;
|
||||
}
|
||||
|
||||
static std::mutex &globalMutex() {
|
||||
static std::mutex mutex_;
|
||||
return mutex_;
|
||||
}
|
||||
|
||||
static EventLoop *getLoop() {
|
||||
static EventLoopThread loopThread;
|
||||
loopThread.run();
|
||||
return loopThread.getLoop();
|
||||
}
|
||||
|
||||
const size_t timeout_{ 60 };
|
||||
|
||||
void onRead(int sockfd);
|
||||
void onTimer();
|
||||
void onQueryResult(int status,
|
||||
struct hostent *result,
|
||||
const std::string &hostname,
|
||||
const Callback &callback);
|
||||
void onQueryResult(int status, struct hostent *result, const std::string &hostname, const Callback &callback);
|
||||
void onSockCreate(int sockfd, int type);
|
||||
void onSockStateChange(int sockfd, bool read, bool write);
|
||||
|
||||
static void ares_hostcallback_(void *data,
|
||||
int status,
|
||||
int timeouts,
|
||||
struct hostent *hostent);
|
||||
static void ares_hostcallback_(void *data, int status, int timeouts, struct hostent *hostent);
|
||||
|
||||
#ifdef _WIN32
|
||||
static int ares_sock_createcallback_(SOCKET sockfd, int type, void *data);
|
||||
#else
|
||||
@ -152,9 +153,11 @@ private:
|
||||
#endif
|
||||
int read,
|
||||
int write);
|
||||
|
||||
struct LibraryInitializer {
|
||||
LibraryInitializer();
|
||||
~LibraryInitializer();
|
||||
};
|
||||
|
||||
static LibraryInitializer libraryInitializer_;
|
||||
};
|
||||
|
@ -44,9 +44,11 @@ std::shared_ptr<Resolver> Resolver::newResolver(EventLoop *,
|
||||
size_t timeout) {
|
||||
return std::make_shared<NormalResolver>(timeout);
|
||||
}
|
||||
|
||||
bool Resolver::isCAresUsed() {
|
||||
return false;
|
||||
}
|
||||
|
||||
void NormalResolver::resolve(const std::string &hostname,
|
||||
const Callback &callback) {
|
||||
{
|
||||
|
@ -35,10 +35,12 @@
|
||||
#include <thread>
|
||||
#include <vector>
|
||||
|
||||
//Resolver will be a ref
|
||||
|
||||
constexpr size_t kResolveBufferLength{ 16 * 1024 };
|
||||
|
||||
class NormalResolver : public Resolver,
|
||||
public std::enable_shared_from_this<NormalResolver> {
|
||||
class NormalResolver : public Resolver, public std::enable_shared_from_this<NormalResolver> {
|
||||
|
||||
protected:
|
||||
NormalResolver(const NormalResolver &) = delete;
|
||||
NormalResolver &operator=(const NormalResolver &) = delete;
|
||||
@ -56,25 +58,23 @@ public:
|
||||
}
|
||||
|
||||
private:
|
||||
static std::unordered_map<std::string,
|
||||
std::pair<InetAddress, Date> > &
|
||||
globalCache() {
|
||||
static std::unordered_map<
|
||||
std::string,
|
||||
std::pair<InetAddress, Date> >
|
||||
dnsCache_;
|
||||
static std::unordered_map<std::string, std::pair<InetAddress, Date> > &globalCache() {
|
||||
static std::unordered_map<std::string, std::pair<InetAddress, Date> > dnsCache_;
|
||||
return dnsCache_;
|
||||
}
|
||||
|
||||
static std::mutex &globalMutex() {
|
||||
static std::mutex mutex_;
|
||||
return mutex_;
|
||||
}
|
||||
|
||||
static ConcurrentTaskQueue &concurrentTaskQueue() {
|
||||
static ConcurrentTaskQueue queue(
|
||||
std::thread::hardware_concurrency() < 8 ? 8 : std::thread::hardware_concurrency(),
|
||||
"Dns Queue");
|
||||
return queue;
|
||||
}
|
||||
|
||||
const size_t timeout_;
|
||||
std::vector<char> resolveBuffer_;
|
||||
};
|
||||
|
@ -29,9 +29,9 @@
|
||||
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
#include "socket.h"
|
||||
#include "core/log/logger.h"
|
||||
#include <assert.h>
|
||||
#include <sys/types.h>
|
||||
#include "core/log/logger.h"
|
||||
#ifdef _WIN32
|
||||
#include <ws2tcpip.h>
|
||||
#else
|
||||
@ -39,7 +39,6 @@
|
||||
#include <sys/socket.h>
|
||||
#endif
|
||||
|
||||
|
||||
bool Socket::isSelfConnect(int sockfd) {
|
||||
struct sockaddr_in6 localaddr = getLocalAddr(sockfd);
|
||||
struct sockaddr_in6 peeraddr = getPeerAddr(sockfd);
|
||||
@ -75,6 +74,7 @@ void Socket::bindAddress(const InetAddress &localaddr) {
|
||||
exit(1);
|
||||
}
|
||||
}
|
||||
|
||||
void Socket::listen() {
|
||||
assert(sockFd_ > 0);
|
||||
int ret = ::listen(sockFd_, SOMAXCONN);
|
||||
@ -83,6 +83,7 @@ void Socket::listen() {
|
||||
exit(1);
|
||||
}
|
||||
}
|
||||
|
||||
int Socket::accept(InetAddress *peeraddr) {
|
||||
struct sockaddr_in6 addr6;
|
||||
memset(&addr6, 0, sizeof(addr6));
|
||||
@ -102,6 +103,7 @@ int Socket::accept(InetAddress *peeraddr) {
|
||||
}
|
||||
return connfd;
|
||||
}
|
||||
|
||||
void Socket::closeWrite() {
|
||||
#ifndef _WIN32
|
||||
if (::shutdown(sockFd_, SHUT_WR) < 0)
|
||||
@ -112,6 +114,7 @@ void Socket::closeWrite() {
|
||||
LOG_SYSERR << "sockets::shutdownWrite";
|
||||
}
|
||||
}
|
||||
|
||||
int Socket::read(char *buffer, uint64_t len) {
|
||||
#ifndef _WIN32
|
||||
return ::read(sockFd_, buffer, len);
|
||||
|
@ -32,8 +32,8 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "core/net/inet_address.h"
|
||||
#include "core/log/logger.h"
|
||||
#include "core/net/inet_address.h"
|
||||
#include <string>
|
||||
#ifndef _WIN32
|
||||
#include <unistd.h>
|
||||
@ -41,7 +41,6 @@
|
||||
#include <fcntl.h>
|
||||
|
||||
class Socket {
|
||||
|
||||
protected:
|
||||
Socket(const Socket &) = delete;
|
||||
Socket &operator=(const Socket &) = delete;
|
||||
@ -63,6 +62,7 @@ public:
|
||||
LOG_SYSERR << "sockets::createNonblockingOrDie";
|
||||
exit(1);
|
||||
}
|
||||
|
||||
LOG_TRACE << "sock=" << sock;
|
||||
return sock;
|
||||
}
|
||||
@ -85,15 +85,9 @@ public:
|
||||
|
||||
static int connect(int sockfd, const InetAddress &addr) {
|
||||
if (addr.isIpV6())
|
||||
return ::connect(sockfd,
|
||||
addr.getSockAddr(),
|
||||
static_cast<socklen_t>(
|
||||
sizeof(struct sockaddr_in6)));
|
||||
return ::connect(sockfd, addr.getSockAddr(), static_cast<socklen_t>(sizeof(struct sockaddr_in6)));
|
||||
else
|
||||
return ::connect(sockfd,
|
||||
addr.getSockAddr(),
|
||||
static_cast<socklen_t>(
|
||||
sizeof(struct sockaddr_in)));
|
||||
return ::connect(sockfd, addr.getSockAddr(), static_cast<socklen_t>(sizeof(struct sockaddr_in)));
|
||||
}
|
||||
|
||||
static bool isSelfConnect(int sockfd);
|
||||
@ -101,7 +95,9 @@ public:
|
||||
explicit Socket(int sockfd) :
|
||||
sockFd_(sockfd) {
|
||||
}
|
||||
|
||||
~Socket();
|
||||
|
||||
/// abort if address in use
|
||||
void bindAddress(const InetAddress &localaddr);
|
||||
/// abort if address in use
|
||||
@ -109,9 +105,11 @@ public:
|
||||
int accept(InetAddress *peeraddr);
|
||||
void closeWrite();
|
||||
int read(char *buffer, uint64_t len);
|
||||
|
||||
int fd() {
|
||||
return sockFd_;
|
||||
}
|
||||
|
||||
static struct sockaddr_in6 getLocalAddr(int sockfd);
|
||||
static struct sockaddr_in6 getPeerAddr(int sockfd);
|
||||
|
||||
|
@ -54,9 +54,7 @@ TcpClient::IgnoreSigPipe TcpClient::initObj;
|
||||
#endif
|
||||
|
||||
static void defaultConnectionCallback(const TcpConnectionPtr &conn) {
|
||||
LOG_TRACE << conn->localAddr().toIpPort() << " -> "
|
||||
<< conn->peerAddr().toIpPort() << " is "
|
||||
<< (conn->connected() ? "UP" : "DOWN");
|
||||
LOG_TRACE << conn->localAddr().toIpPort() << " -> " << conn->peerAddr().toIpPort() << " is " << (conn->connected() ? "UP" : "DOWN");
|
||||
// do not call conn->forceClose(), because some users want to register
|
||||
// message callback only.
|
||||
}
|
||||
@ -65,9 +63,7 @@ static void defaultMessageCallback(const TcpConnectionPtr &, MsgBuffer *buf) {
|
||||
buf->retrieveAll();
|
||||
}
|
||||
|
||||
TcpClient::TcpClient(EventLoop *loop,
|
||||
const InetAddress &serverAddr,
|
||||
const std::string &nameArg) :
|
||||
TcpClient::TcpClient(EventLoop *loop, const InetAddress &serverAddr, const std::string &nameArg) :
|
||||
loop_(loop),
|
||||
connector_(new Connector(loop, serverAddr, false)),
|
||||
name_(nameArg),
|
||||
@ -75,23 +71,27 @@ TcpClient::TcpClient(EventLoop *loop,
|
||||
messageCallback_(defaultMessageCallback),
|
||||
retry_(false),
|
||||
connect_(true) {
|
||||
connector_->setNewConnectionCallback(
|
||||
std::bind(&TcpClient::newConnection, this, _1));
|
||||
|
||||
connector_->setNewConnectionCallback(std::bind(&TcpClient::newConnection, this, _1));
|
||||
|
||||
connector_->setErrorCallback([this]() {
|
||||
if (connectionErrorCallback_) {
|
||||
connectionErrorCallback_();
|
||||
}
|
||||
});
|
||||
|
||||
LOG_TRACE << "TcpClient::TcpClient[" << name_ << "] - connector ";
|
||||
}
|
||||
|
||||
TcpClient::~TcpClient() {
|
||||
LOG_TRACE << "TcpClient::~TcpClient[" << name_ << "] - connector ";
|
||||
|
||||
TcpConnectionImplPtr conn;
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(mutex_);
|
||||
conn = std::dynamic_pointer_cast<TcpConnectionImpl>(connection_);
|
||||
}
|
||||
|
||||
if (conn) {
|
||||
assert(loop_ == conn->getLoop());
|
||||
// TODO: not 100% safe, if we are in different thread
|
||||
@ -104,6 +104,7 @@ TcpClient::~TcpClient() {
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
conn->forceClose();
|
||||
} else {
|
||||
/// TODO need test in this condition
|
||||
@ -142,39 +143,34 @@ void TcpClient::newConnection(int sockfd) {
|
||||
// TODO poll with zero timeout to double confirm the new connection
|
||||
// TODO use make_shared if necessary
|
||||
std::shared_ptr<TcpConnectionImpl> conn;
|
||||
|
||||
if (sslCtxPtr_) {
|
||||
#ifdef USE_OPENSSL
|
||||
conn = std::make_shared<TcpConnectionImpl>(loop_,
|
||||
sockfd,
|
||||
localAddr,
|
||||
peerAddr,
|
||||
sslCtxPtr_,
|
||||
false,
|
||||
validateCert_,
|
||||
SSLHostName_);
|
||||
conn = std::make_shared<TcpConnectionImpl>(loop_, sockfd, localAddr, peerAddr, sslCtxPtr_, false, validateCert_, SSLHostName_);
|
||||
#else
|
||||
LOG_FATAL << "OpenSSL is not found in your system!";
|
||||
abort();
|
||||
#endif
|
||||
} else {
|
||||
conn = std::make_shared<TcpConnectionImpl>(loop_,
|
||||
sockfd,
|
||||
localAddr,
|
||||
peerAddr);
|
||||
conn = std::make_shared<TcpConnectionImpl>(loop_, sockfd, localAddr, peerAddr);
|
||||
}
|
||||
|
||||
conn->setConnectionCallback(connectionCallback_);
|
||||
conn->setRecvMsgCallback(messageCallback_);
|
||||
conn->setWriteCompleteCallback(writeCompleteCallback_);
|
||||
conn->setCloseCallback(std::bind(&TcpClient::removeConnection, this, _1));
|
||||
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(mutex_);
|
||||
connection_ = conn;
|
||||
}
|
||||
|
||||
conn->setSSLErrorCallback([this](SSLError err) {
|
||||
if (sslErrorCallback_) {
|
||||
sslErrorCallback_(err);
|
||||
}
|
||||
});
|
||||
|
||||
conn->connectEstablished();
|
||||
}
|
||||
|
||||
@ -188,30 +184,23 @@ void TcpClient::removeConnection(const TcpConnectionPtr &conn) {
|
||||
connection_.reset();
|
||||
}
|
||||
|
||||
loop_->queueInLoop(
|
||||
std::bind(&TcpConnectionImpl::connectDestroyed,
|
||||
std::dynamic_pointer_cast<TcpConnectionImpl>(conn)));
|
||||
loop_->queueInLoop(std::bind(&TcpConnectionImpl::connectDestroyed, std::dynamic_pointer_cast<TcpConnectionImpl>(conn)));
|
||||
|
||||
if (retry_ && connect_) {
|
||||
LOG_TRACE << "TcpClient::connect[" << name_ << "] - Reconnecting to "
|
||||
<< connector_->serverAddress().toIpPort();
|
||||
LOG_TRACE << "TcpClient::connect[" << name_ << "] - Reconnecting to " << connector_->serverAddress().toIpPort();
|
||||
connector_->restart();
|
||||
}
|
||||
}
|
||||
|
||||
void TcpClient::enableSSL(
|
||||
bool useOldTLS,
|
||||
bool validateCert,
|
||||
std::string hostname,
|
||||
const std::vector<std::pair<std::string, std::string> > &sslConfCmds) {
|
||||
void TcpClient::enableSSL(bool useOldTLS, bool validateCert, std::string hostname, const std::vector<std::pair<std::string, std::string> > &sslConfCmds) {
|
||||
#ifdef USE_OPENSSL
|
||||
/* Create a new OpenSSL context */
|
||||
sslCtxPtr_ = newSSLContext(useOldTLS, validateCert, sslConfCmds);
|
||||
|
||||
validateCert_ = validateCert;
|
||||
|
||||
if (!hostname.empty()) {
|
||||
std::transform(hostname.begin(),
|
||||
hostname.end(),
|
||||
hostname.begin(),
|
||||
tolower);
|
||||
std::transform(hostname.begin(), hostname.end(), hostname.begin(), tolower);
|
||||
SSLHostName_ = std::move(hostname);
|
||||
}
|
||||
|
||||
|
@ -42,10 +42,7 @@
|
||||
class Connector;
|
||||
using ConnectorPtr = std::shared_ptr<Connector>;
|
||||
class SSLContext;
|
||||
/**
|
||||
* @brief This class represents a TCP client.
|
||||
*
|
||||
*/
|
||||
|
||||
class TcpClient {
|
||||
protected:
|
||||
TcpClient(const TcpClient &) = delete;
|
||||
@ -55,88 +52,35 @@ protected:
|
||||
TcpClient &operator=(TcpClient &&) noexcept(true) = default;
|
||||
|
||||
public:
|
||||
/**
|
||||
* @brief Construct a new TCP client instance.
|
||||
*
|
||||
* @param loop The event loop in which the client runs.
|
||||
* @param serverAddr The address of the server.
|
||||
* @param nameArg The name of the client.
|
||||
*/
|
||||
TcpClient(EventLoop *loop,
|
||||
const InetAddress &serverAddr,
|
||||
const std::string &nameArg);
|
||||
TcpClient(EventLoop *loop, const InetAddress &serverAddr, const std::string &nameArg);
|
||||
~TcpClient();
|
||||
|
||||
/**
|
||||
* @brief Connect to the server.
|
||||
*
|
||||
*/
|
||||
void connect();
|
||||
|
||||
/**
|
||||
* @brief Disconnect from the server.
|
||||
*
|
||||
*/
|
||||
void disconnect();
|
||||
|
||||
/**
|
||||
* @brief Stop connecting to the server.
|
||||
*
|
||||
*/
|
||||
void stop();
|
||||
|
||||
/**
|
||||
* @brief Get the TCP connection to the server.
|
||||
*
|
||||
* @return TcpConnectionPtr
|
||||
*/
|
||||
TcpConnectionPtr connection() const {
|
||||
std::lock_guard<std::mutex> lock(mutex_);
|
||||
return connection_;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Get the event loop.
|
||||
*
|
||||
* @return EventLoop*
|
||||
*/
|
||||
EventLoop *getLoop() const {
|
||||
return loop_;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Check whether the client re-connect to the server.
|
||||
*
|
||||
* @return true
|
||||
* @return false
|
||||
*/
|
||||
bool retry() const {
|
||||
return retry_;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Enable retrying.
|
||||
*
|
||||
*/
|
||||
void enableRetry() {
|
||||
retry_ = true;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Get the name of the client.
|
||||
*
|
||||
* @return const std::string&
|
||||
*/
|
||||
const std::string &name() const {
|
||||
return name_;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Set the connection callback.
|
||||
*
|
||||
* @param cb The callback is called when the connection to the server is
|
||||
* established or closed.
|
||||
*/
|
||||
void setConnectionCallback(const ConnectionCallback &cb) {
|
||||
connectionCallback_ = cb;
|
||||
}
|
||||
@ -144,37 +88,18 @@ public:
|
||||
connectionCallback_ = std::move(cb);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Set the connection error callback.
|
||||
*
|
||||
* @param cb The callback is called when an error occurs during connecting
|
||||
* to the server.
|
||||
*/
|
||||
void setConnectionErrorCallback(const ConnectionErrorCallback &cb) {
|
||||
connectionErrorCallback_ = cb;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Set the message callback.
|
||||
*
|
||||
* @param cb The callback is called when some data is received from the
|
||||
* server.
|
||||
*/
|
||||
void setMessageCallback(const RecvMessageCallback &cb) {
|
||||
messageCallback_ = cb;
|
||||
}
|
||||
void setMessageCallback(RecvMessageCallback &&cb) {
|
||||
messageCallback_ = std::move(cb);
|
||||
}
|
||||
/// Set write complete callback.
|
||||
/// Not thread safe.
|
||||
|
||||
/**
|
||||
* @brief Set the write complete callback.
|
||||
*
|
||||
* @param cb The callback is called when data to send is written to the
|
||||
* socket.
|
||||
*/
|
||||
/// Not thread safe.
|
||||
void setWriteCompleteCallback(const WriteCompleteCallback &cb) {
|
||||
writeCompleteCallback_ = cb;
|
||||
}
|
||||
@ -182,10 +107,6 @@ public:
|
||||
writeCompleteCallback_ = std::move(cb);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Set the callback for errors of SSL
|
||||
* @param cb The callback is called when an SSL error occurs.
|
||||
*/
|
||||
void setSSLErrorCallback(const SSLErrorCallback &cb) {
|
||||
sslErrorCallback_ = cb;
|
||||
}
|
||||
@ -193,24 +114,7 @@ public:
|
||||
sslErrorCallback_ = std::move(cb);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Enable SSL encryption.
|
||||
* @param useOldTLS If true, the TLS 1.0 and 1.1 are supported by the
|
||||
* client.
|
||||
* @param validateCert If true, we try to validate if the peer's SSL cert
|
||||
* is valid.
|
||||
* @param hostname The server hostname for SNI. If it is empty, the SNI is
|
||||
* not used.
|
||||
* @param sslConfCmds The commands used to call the SSL_CONF_cmd function in
|
||||
* OpenSSL.
|
||||
* @note It's well known that TLS 1.0 and 1.1 are not considered secure in
|
||||
* 2020. And it's a good practice to only use TLS 1.2 and above.
|
||||
*/
|
||||
void enableSSL(bool useOldTLS = false,
|
||||
bool validateCert = true,
|
||||
std::string hostname = "",
|
||||
const std::vector<std::pair<std::string, std::string> >
|
||||
&sslConfCmds = {});
|
||||
void enableSSL(bool useOldTLS = false, bool validateCert = true, std::string hostname = "", const std::vector<std::pair<std::string, std::string> > &sslConfCmds = {});
|
||||
|
||||
private:
|
||||
/// Not thread safe, but in loop
|
||||
@ -234,6 +138,7 @@ private:
|
||||
std::shared_ptr<SSLContext> sslCtxPtr_;
|
||||
bool validateCert_{ false };
|
||||
std::string SSLHostName_;
|
||||
|
||||
#ifndef _WIN32
|
||||
class IgnoreSigPipe {
|
||||
public:
|
||||
|
@ -39,26 +39,13 @@
|
||||
#include <string>
|
||||
|
||||
class SSLContext;
|
||||
std::shared_ptr<SSLContext> newSSLServerContext(
|
||||
const std::string &certPath,
|
||||
const std::string &keyPath,
|
||||
bool useOldTLS = false,
|
||||
const std::vector<std::pair<std::string, std::string> > &sslConfCmds = {});
|
||||
/**
|
||||
* @brief This class represents a TCP connection.
|
||||
*
|
||||
*/
|
||||
std::shared_ptr<SSLContext> newSSLServerContext(const std::string &certPath, const std::string &keyPath, bool useOldTLS = false, const std::vector<std::pair<std::string, std::string> > &sslConfCmds = {});
|
||||
|
||||
class TcpConnection {
|
||||
public:
|
||||
TcpConnection() = default;
|
||||
virtual ~TcpConnection(){};
|
||||
|
||||
/**
|
||||
* @brief Send some data to the peer.
|
||||
*
|
||||
* @param msg
|
||||
* @param len
|
||||
*/
|
||||
virtual void send(const char *msg, size_t len) = 0;
|
||||
virtual void send(const void *msg, size_t len) = 0;
|
||||
virtual void send(const std::string &msg) = 0;
|
||||
@ -68,95 +55,27 @@ public:
|
||||
virtual void send(const std::shared_ptr<std::string> &msgPtr) = 0;
|
||||
virtual void send(const std::shared_ptr<MsgBuffer> &msgPtr) = 0;
|
||||
|
||||
/**
|
||||
* @brief Send a file to the peer.
|
||||
*
|
||||
* @param fileName
|
||||
* @param offset
|
||||
* @param length
|
||||
*/
|
||||
virtual void sendFile(const char *fileName,
|
||||
size_t offset = 0,
|
||||
size_t length = 0) = 0;
|
||||
virtual void sendFile(const char *fileName, size_t offset = 0, size_t length = 0) = 0;
|
||||
|
||||
/**
|
||||
* @brief Get the local address of the connection.
|
||||
*
|
||||
* @return const InetAddress&
|
||||
*/
|
||||
virtual const InetAddress &localAddr() const = 0;
|
||||
|
||||
/**
|
||||
* @brief Get the remote address of the connection.
|
||||
*
|
||||
* @return const InetAddress&
|
||||
*/
|
||||
virtual const InetAddress &peerAddr() const = 0;
|
||||
|
||||
/**
|
||||
* @brief Return true if the connection is established.
|
||||
*
|
||||
* @return true
|
||||
* @return false
|
||||
*/
|
||||
virtual bool connected() const = 0;
|
||||
|
||||
/**
|
||||
* @brief Return false if the connection is established.
|
||||
*
|
||||
* @return true
|
||||
* @return false
|
||||
*/
|
||||
virtual bool disconnected() const = 0;
|
||||
|
||||
/**
|
||||
* @brief Get the buffer in which the received data stored.
|
||||
*
|
||||
* @return MsgBuffer*
|
||||
*/
|
||||
virtual MsgBuffer *getRecvBuffer() = 0;
|
||||
|
||||
/**
|
||||
* @brief Set the high water mark callback
|
||||
*
|
||||
* @param cb The callback is called when the data in sending buffer is
|
||||
* larger than the water mark.
|
||||
* @param markLen The water mark in bytes.
|
||||
*/
|
||||
virtual void setHighWaterMarkCallback(const HighWaterMarkCallback &cb,
|
||||
size_t markLen) = 0;
|
||||
|
||||
/**
|
||||
* @brief Set the TCP_NODELAY option to the socket.
|
||||
*
|
||||
* @param on
|
||||
*/
|
||||
virtual void setHighWaterMarkCallback(const HighWaterMarkCallback &cb, size_t markLen) = 0;
|
||||
|
||||
virtual void setTcpNoDelay(bool on) = 0;
|
||||
|
||||
/**
|
||||
* @brief Shutdown the connection.
|
||||
* @note This method only closes the writing direction.
|
||||
*/
|
||||
virtual void shutdown() = 0;
|
||||
|
||||
/**
|
||||
* @brief Close the connection forcefully.
|
||||
*
|
||||
*/
|
||||
virtual void forceClose() = 0;
|
||||
|
||||
/**
|
||||
* @brief Get the event loop in which the connection I/O is handled.
|
||||
*
|
||||
* @return EventLoop*
|
||||
*/
|
||||
|
||||
virtual EventLoop *getLoop() = 0;
|
||||
|
||||
/**
|
||||
* @brief Set the custom data on the connection.
|
||||
*
|
||||
* @param context
|
||||
*/
|
||||
void setContext(const std::shared_ptr<void> &context) {
|
||||
contextPtr_ = context;
|
||||
}
|
||||
@ -164,98 +83,30 @@ public:
|
||||
contextPtr_ = std::move(context);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Get the custom data from the connection.
|
||||
*
|
||||
* @tparam T
|
||||
* @return std::shared_ptr<T>
|
||||
*/
|
||||
template <typename T>
|
||||
std::shared_ptr<T> getContext() const {
|
||||
return std::static_pointer_cast<T>(contextPtr_);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Return true if the custom data is set by user.
|
||||
*
|
||||
* @return true
|
||||
* @return false
|
||||
*/
|
||||
bool hasContext() const {
|
||||
return (bool)contextPtr_;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Clear the custom data.
|
||||
*
|
||||
*/
|
||||
void clearContext() {
|
||||
contextPtr_.reset();
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Call this method to avoid being kicked off by TcpServer, refer to
|
||||
* the kickoffIdleConnections method in the TcpServer class.
|
||||
*
|
||||
*/
|
||||
virtual void keepAlive() = 0;
|
||||
|
||||
/**
|
||||
* @brief Return true if the keepAlive() method is called.
|
||||
*
|
||||
* @return true
|
||||
* @return false
|
||||
*/
|
||||
virtual bool isKeepAlive() = 0;
|
||||
|
||||
/**
|
||||
* @brief Return the number of bytes sent
|
||||
*
|
||||
* @return size_t
|
||||
*/
|
||||
virtual size_t bytesSent() const = 0;
|
||||
|
||||
/**
|
||||
* @brief Return the number of bytes received.
|
||||
*
|
||||
* @return size_t
|
||||
*/
|
||||
virtual size_t bytesReceived() const = 0;
|
||||
|
||||
/**
|
||||
* @brief Check whether the connection is SSL encrypted.
|
||||
*
|
||||
* @return true
|
||||
* @return false
|
||||
*/
|
||||
virtual bool isSSLConnection() const = 0;
|
||||
|
||||
/**
|
||||
* @brief Start the SSL encryption on the connection (as a client).
|
||||
*
|
||||
* @param callback The callback is called when the SSL connection is
|
||||
* established.
|
||||
* @param hostname The server hostname for SNI. If it is empty, the SNI is
|
||||
* not used.
|
||||
* @param sslConfCmds The commands used to call the SSL_CONF_cmd function in
|
||||
* OpenSSL.
|
||||
*/
|
||||
virtual void startClientEncryption(
|
||||
std::function<void()> callback,
|
||||
bool useOldTLS = false,
|
||||
bool validateCert = true,
|
||||
std::string hostname = "",
|
||||
const std::vector<std::pair<std::string, std::string> > &sslConfCmds = {}) = 0;
|
||||
virtual void startClientEncryption(std::function<void()> callback, bool useOldTLS = false, bool validateCert = true, std::string hostname = "", const std::vector<std::pair<std::string, std::string> > &sslConfCmds = {}) = 0;
|
||||
|
||||
/**
|
||||
* @brief Start the SSL encryption on the connection (as a server).
|
||||
*
|
||||
* @param ctx The SSL context.
|
||||
* @param callback The callback is called when the SSL connection is
|
||||
* established.
|
||||
*/
|
||||
virtual void startServerEncryption(const std::shared_ptr<SSLContext> &ctx,
|
||||
std::function<void()> callback) = 0;
|
||||
virtual void startServerEncryption(const std::shared_ptr<SSLContext> &ctx, std::function<void()> callback) = 0;
|
||||
|
||||
protected:
|
||||
bool validateCert_ = false;
|
||||
|
@ -28,20 +28,16 @@
|
||||
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
#include "core/loops/acceptor.h"
|
||||
#include "core/net/connections/tcp_connection_impl.h"
|
||||
#include "core/net/tcp_server.h"
|
||||
#include "core/log/logger.h"
|
||||
#include "core/loops/acceptor.h"
|
||||
#include "core/net/connections/tcp_connection_impl.h"
|
||||
#include <functional>
|
||||
#include <vector>
|
||||
|
||||
using namespace std::placeholders;
|
||||
|
||||
TcpServer::TcpServer(EventLoop *loop,
|
||||
const InetAddress &address,
|
||||
const std::string &name,
|
||||
bool reUseAddr,
|
||||
bool reUsePort) :
|
||||
TcpServer::TcpServer(EventLoop *loop, const InetAddress &address, const std::string &name, bool reUseAddr, bool reUsePort) :
|
||||
loop_(loop),
|
||||
acceptorPtr_(new Acceptor(loop, address, reUseAddr, reUsePort)),
|
||||
serverName_(name),
|
||||
@ -50,8 +46,8 @@ TcpServer::TcpServer(EventLoop *loop,
|
||||
<< " bytes]";
|
||||
buffer->retrieveAll();
|
||||
}) {
|
||||
acceptorPtr_->setNewConnectionCallback(
|
||||
std::bind(&TcpServer::newConnection, this, _1, _2));
|
||||
|
||||
acceptorPtr_->setNewConnectionCallback(std::bind(&TcpServer::newConnection, this, _1, _2));
|
||||
}
|
||||
|
||||
TcpServer::~TcpServer() {
|
||||
@ -69,14 +65,20 @@ void TcpServer::newConnection(int sockfd, const InetAddress &peer) {
|
||||
// LOG_TRACE<<"vector size:"<<str.size();
|
||||
// size_t n=write(sockfd,&str[0],str.size());
|
||||
// LOG_TRACE<<"write "<<n<<" bytes";
|
||||
|
||||
loop_->assertInLoopThread();
|
||||
|
||||
EventLoop *ioLoop = NULL;
|
||||
if (loopPoolPtr_ && loopPoolPtr_->size() > 0) {
|
||||
ioLoop = loopPoolPtr_->getNextLoop();
|
||||
}
|
||||
if (ioLoop == NULL)
|
||||
|
||||
if (ioLoop == NULL) {
|
||||
ioLoop = loop_;
|
||||
}
|
||||
|
||||
std::shared_ptr<TcpConnectionImpl> newPtr;
|
||||
|
||||
if (sslCtxPtr_) {
|
||||
#ifdef USE_OPENSSL
|
||||
newPtr = std::make_shared<TcpConnectionImpl>(
|
||||
@ -90,14 +92,14 @@ void TcpServer::newConnection(int sockfd, const InetAddress &peer) {
|
||||
abort();
|
||||
#endif
|
||||
} else {
|
||||
newPtr = std::make_shared<TcpConnectionImpl>(
|
||||
ioLoop, sockfd, InetAddress(Socket::getLocalAddr(sockfd)), peer);
|
||||
newPtr = std::make_shared<TcpConnectionImpl>(ioLoop, sockfd, InetAddress(Socket::getLocalAddr(sockfd)), peer);
|
||||
}
|
||||
|
||||
if (idleTimeout_ > 0) {
|
||||
assert(timingWheelMap_[ioLoop]);
|
||||
newPtr->enableKickingOff(idleTimeout_, timingWheelMap_[ioLoop]);
|
||||
}
|
||||
|
||||
newPtr->setRecvMsgCallback(recvMessageCallback_);
|
||||
|
||||
newPtr->setConnectionCallback(
|
||||
@ -105,11 +107,13 @@ void TcpServer::newConnection(int sockfd, const InetAddress &peer) {
|
||||
if (connectionCallback_)
|
||||
connectionCallback_(connectionPtr);
|
||||
});
|
||||
|
||||
newPtr->setWriteCompleteCallback(
|
||||
[this](const TcpConnectionPtr &connectionPtr) {
|
||||
if (writeCompleteCallback_)
|
||||
writeCompleteCallback_(connectionPtr);
|
||||
});
|
||||
|
||||
newPtr->setCloseCallback(std::bind(&TcpServer::connectionClosed, this, _1));
|
||||
connSet_.insert(newPtr);
|
||||
newPtr->connectEstablished();
|
||||
@ -143,6 +147,7 @@ void TcpServer::start() {
|
||||
acceptorPtr_->listen();
|
||||
});
|
||||
}
|
||||
|
||||
void TcpServer::stop() {
|
||||
loop_->runInLoop([this]() { acceptorPtr_.reset(); });
|
||||
for (auto connection : connSet_) {
|
||||
@ -159,6 +164,7 @@ void TcpServer::stop() {
|
||||
f.get();
|
||||
}
|
||||
}
|
||||
|
||||
void TcpServer::connectionClosed(const TcpConnectionPtr &connectionPtr) {
|
||||
LOG_TRACE << "connectionClosed";
|
||||
// loop_->assertInLoopThread();
|
||||
@ -179,11 +185,7 @@ const InetAddress &TcpServer::address() const {
|
||||
return acceptorPtr_->addr();
|
||||
}
|
||||
|
||||
void TcpServer::enableSSL(
|
||||
const std::string &certPath,
|
||||
const std::string &keyPath,
|
||||
bool useOldTLS,
|
||||
const std::vector<std::pair<std::string, std::string> > &sslConfCmds) {
|
||||
void TcpServer::enableSSL(const std::string &certPath, const std::string &keyPath, bool useOldTLS, const std::vector<std::pair<std::string, std::string> > &sslConfCmds) {
|
||||
#ifdef USE_OPENSSL
|
||||
/* Create a new OpenSSL context */
|
||||
sslCtxPtr_ = newSSLServerContext(certPath, keyPath, useOldTLS, sslConfCmds);
|
||||
|
@ -43,10 +43,7 @@
|
||||
|
||||
class Acceptor;
|
||||
class SSLContext;
|
||||
/**
|
||||
* @brief This class represents a TCP server.
|
||||
*
|
||||
*/
|
||||
|
||||
class TcpServer {
|
||||
protected:
|
||||
TcpServer(const TcpServer &) = delete;
|
||||
@ -56,53 +53,18 @@ protected:
|
||||
TcpServer &operator=(TcpServer &&) noexcept(true) = default;
|
||||
|
||||
public:
|
||||
/**
|
||||
* @brief Construct a new TCP server instance.
|
||||
*
|
||||
* @param loop The event loop in which the acceptor of the server is
|
||||
* handled.
|
||||
* @param address The address of the server.
|
||||
* @param name The name of the server.
|
||||
* @param reUseAddr The SO_REUSEADDR option.
|
||||
* @param reUsePort The SO_REUSEPORT option.
|
||||
*/
|
||||
TcpServer(EventLoop *loop,
|
||||
const InetAddress &address,
|
||||
const std::string &name,
|
||||
bool reUseAddr = true,
|
||||
bool reUsePort = true);
|
||||
TcpServer(EventLoop *loop, const InetAddress &address, const std::string &name, bool reUseAddr = true, bool reUsePort = true);
|
||||
~TcpServer();
|
||||
|
||||
/**
|
||||
* @brief Start the server.
|
||||
*
|
||||
*/
|
||||
void start();
|
||||
|
||||
/**
|
||||
* @brief Stop the server.
|
||||
*
|
||||
*/
|
||||
void stop();
|
||||
|
||||
/**
|
||||
* @brief Set the number of event loops in which the I/O of connections to
|
||||
* the server is handled.
|
||||
*
|
||||
* @param num
|
||||
*/
|
||||
void setIoLoopNum(size_t num) {
|
||||
assert(!started_);
|
||||
loopPoolPtr_ = std::make_shared<EventLoopThreadPool>(num);
|
||||
loopPoolPtr_->start();
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Set the event loops pool in which the I/O of connections to
|
||||
* the server is handled.
|
||||
*
|
||||
* @param pool
|
||||
*/
|
||||
void setIoLoopThreadPool(const std::shared_ptr<EventLoopThreadPool> &pool) {
|
||||
assert(pool->size() > 0);
|
||||
assert(!started_);
|
||||
@ -110,12 +72,6 @@ public:
|
||||
loopPoolPtr_->start();
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Set the message callback.
|
||||
*
|
||||
* @param cb The callback is called when some data is received on a
|
||||
* connection to the server.
|
||||
*/
|
||||
void setRecvMessageCallback(const RecvMessageCallback &cb) {
|
||||
recvMessageCallback_ = cb;
|
||||
}
|
||||
@ -123,12 +79,6 @@ public:
|
||||
recvMessageCallback_ = std::move(cb);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Set the connection callback.
|
||||
*
|
||||
* @param cb The callback is called when a connection is established or
|
||||
* closed.
|
||||
*/
|
||||
void setConnectionCallback(const ConnectionCallback &cb) {
|
||||
connectionCallback_ = cb;
|
||||
}
|
||||
@ -136,12 +86,6 @@ public:
|
||||
connectionCallback_ = std::move(cb);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Set the write complete callback.
|
||||
*
|
||||
* @param cb The callback is called when data to send is written to the
|
||||
* socket of a connection.
|
||||
*/
|
||||
void setWriteCompleteCallback(const WriteCompleteCallback &cb) {
|
||||
writeCompleteCallback_ = cb;
|
||||
}
|
||||
@ -149,53 +93,21 @@ public:
|
||||
writeCompleteCallback_ = std::move(cb);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Get the name of the server.
|
||||
*
|
||||
* @return const std::string&
|
||||
*/
|
||||
const std::string &name() const {
|
||||
return serverName_;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Get the IP and port string of the server.
|
||||
*
|
||||
* @return const std::string
|
||||
*/
|
||||
const std::string ipPort() const;
|
||||
|
||||
/**
|
||||
* @brief Get the address of the server.
|
||||
*
|
||||
* @return const InetAddress&
|
||||
*/
|
||||
const InetAddress &address() const;
|
||||
|
||||
/**
|
||||
* @brief Get the event loop of the server.
|
||||
*
|
||||
* @return EventLoop*
|
||||
*/
|
||||
EventLoop *getLoop() const {
|
||||
return loop_;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Get the I/O event loops of the server.
|
||||
*
|
||||
* @return std::vector<EventLoop *>
|
||||
*/
|
||||
std::vector<EventLoop *> getIoLoops() const {
|
||||
return loopPoolPtr_->getLoops();
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief An idle connection is a connection that has no read or write, kick
|
||||
* off it after timeout seconds.
|
||||
*
|
||||
* @param timeout
|
||||
*/
|
||||
void kickoffIdleConnections(size_t timeout) {
|
||||
loop_->runInLoop([this, timeout]() {
|
||||
assert(!started_);
|
||||
@ -203,23 +115,12 @@ public:
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Enable SSL encryption.
|
||||
*
|
||||
* @param certPath The path of the certificate file.
|
||||
* @param keyPath The path of the private key file.
|
||||
* @param useOldTLS If true, the TLS 1.0 and 1.1 are supported by the
|
||||
* server.
|
||||
* @param sslConfCmds The commands used to call the SSL_CONF_cmd function in
|
||||
* OpenSSL.
|
||||
* @note It's well known that TLS 1.0 and 1.1 are not considered secure in
|
||||
* 2020. And it's a good practice to only use TLS 1.2 and above.
|
||||
*/
|
||||
void enableSSL(const std::string &certPath,
|
||||
const std::string &keyPath,
|
||||
bool useOldTLS = false,
|
||||
const std::vector<std::pair<std::string, std::string> >
|
||||
&sslConfCmds = {});
|
||||
// certPath The path of the certificate file.
|
||||
// keyPath The path of the private key file.
|
||||
// useOldTLS If true, the TLS 1.0 and 1.1 are supported by the server.
|
||||
// sslConfCmds The commands used to call the SSL_CONF_cmd function in OpenSSL.
|
||||
// Note: It's well known that TLS 1.0 and 1.1 are not considered secure in 2020. And it's a good practice to only use TLS 1.2 and above.
|
||||
void enableSSL(const std::string &certPath, const std::string &keyPath, bool useOldTLS = false, const std::vector<std::pair<std::string, std::string> > &sslConfCmds = {});
|
||||
|
||||
private:
|
||||
EventLoop *loop_;
|
||||
@ -236,6 +137,7 @@ private:
|
||||
std::map<EventLoop *, std::shared_ptr<TimingWheel> > timingWheelMap_;
|
||||
void connectionClosed(const TcpConnectionPtr &connectionPtr);
|
||||
std::shared_ptr<EventLoopThreadPool> loopPoolPtr_;
|
||||
|
||||
#ifndef _WIN32
|
||||
class IgnoreSigPipe {
|
||||
public:
|
||||
@ -247,6 +149,7 @@ private:
|
||||
|
||||
IgnoreSigPipe initObj;
|
||||
#endif
|
||||
|
||||
bool started_{ false };
|
||||
|
||||
// OpenSSL SSL context Object;
|
||||
|
Loading…
Reference in New Issue
Block a user