mirror of
https://github.com/Relintai/rcpp_framework.git
synced 2025-05-06 17:51:36 +02:00
initial codestyle cleanup in the net folder.
This commit is contained in:
parent
8992cf49b3
commit
38a28ee9ce
File diff suppressed because it is too large
Load Diff
@ -30,54 +30,45 @@
|
|||||||
|
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include "core/net/tcp_connection.h"
|
|
||||||
#include "core/loops/timing_wheel.h"
|
#include "core/loops/timing_wheel.h"
|
||||||
|
#include "core/net/tcp_connection.h"
|
||||||
#include <list>
|
#include <list>
|
||||||
#include <mutex>
|
#include <mutex>
|
||||||
#ifndef _WIN32
|
#ifndef _WIN32
|
||||||
#include <unistd.h>
|
#include <unistd.h>
|
||||||
#endif
|
#endif
|
||||||
#include <thread>
|
|
||||||
#include <array>
|
#include <array>
|
||||||
|
#include <thread>
|
||||||
|
|
||||||
#ifdef USE_OPENSSL
|
#ifdef USE_OPENSSL
|
||||||
enum class SSLStatus
|
|
||||||
{
|
enum class SSLStatus {
|
||||||
Handshaking,
|
Handshaking,
|
||||||
Connecting,
|
Connecting,
|
||||||
Connected,
|
Connected,
|
||||||
DisConnecting,
|
DisConnecting,
|
||||||
DisConnected
|
DisConnected
|
||||||
};
|
};
|
||||||
|
|
||||||
class SSLContext;
|
class SSLContext;
|
||||||
class SSLConn;
|
class SSLConn;
|
||||||
|
|
||||||
std::shared_ptr<SSLContext> newSSLContext(
|
std::shared_ptr<SSLContext> newSSLContext(bool useOldTLS, bool validateCert, const std::vector<std::pair<std::string, std::string> > &sslConfCmds);
|
||||||
bool useOldTLS,
|
std::shared_ptr<SSLContext> newSSLServerContext(const std::string &certPath, const std::string &keyPath, bool useOldTLS, const std::vector<std::pair<std::string, std::string> > &sslConfCmds);
|
||||||
bool validateCert,
|
// void initServerSSLContext(const std::shared_ptr<SSLContext> &ctx, const std::string &certPath, const std::string &keyPath);
|
||||||
const std::vector<std::pair<std::string, std::string>> &sslConfCmds);
|
|
||||||
std::shared_ptr<SSLContext> newSSLServerContext(
|
|
||||||
const std::string &certPath,
|
|
||||||
const std::string &keyPath,
|
|
||||||
bool useOldTLS,
|
|
||||||
const std::vector<std::pair<std::string, std::string>> &sslConfCmds);
|
|
||||||
// void initServerSSLContext(const std::shared_ptr<SSLContext> &ctx,
|
|
||||||
// const std::string &certPath,
|
|
||||||
// const std::string &keyPath);
|
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
class Channel;
|
class Channel;
|
||||||
class Socket;
|
class Socket;
|
||||||
class TcpServer;
|
class TcpServer;
|
||||||
void removeConnection(EventLoop *loop, const TcpConnectionPtr &conn);
|
void removeConnection(EventLoop *loop, const TcpConnectionPtr &conn);
|
||||||
|
|
||||||
class TcpConnectionImpl : public TcpConnection,
|
class TcpConnectionImpl : public TcpConnection, public std::enable_shared_from_this<TcpConnectionImpl> {
|
||||||
public std::enable_shared_from_this<TcpConnectionImpl>
|
friend class TcpServer;
|
||||||
{
|
friend class TcpClient;
|
||||||
friend class TcpServer;
|
|
||||||
friend class TcpClient;
|
friend void removeConnection(EventLoop *loop, const TcpConnectionPtr &conn);
|
||||||
friend void removeConnection(EventLoop *loop,
|
|
||||||
const TcpConnectionPtr &conn);
|
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
TcpConnectionImpl(const TcpConnectionImpl &) = delete;
|
TcpConnectionImpl(const TcpConnectionImpl &) = delete;
|
||||||
@ -86,278 +77,247 @@ protected:
|
|||||||
TcpConnectionImpl(TcpConnectionImpl &&) noexcept(true) = default;
|
TcpConnectionImpl(TcpConnectionImpl &&) noexcept(true) = default;
|
||||||
TcpConnectionImpl &operator=(TcpConnectionImpl &&) noexcept(true) = default;
|
TcpConnectionImpl &operator=(TcpConnectionImpl &&) noexcept(true) = default;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
class KickoffEntry
|
class KickoffEntry {
|
||||||
{
|
public:
|
||||||
public:
|
explicit KickoffEntry(const std::weak_ptr<TcpConnection> &conn) :
|
||||||
explicit KickoffEntry(const std::weak_ptr<TcpConnection> &conn)
|
conn_(conn) {
|
||||||
: conn_(conn)
|
}
|
||||||
{
|
|
||||||
}
|
|
||||||
void reset()
|
|
||||||
{
|
|
||||||
conn_.reset();
|
|
||||||
}
|
|
||||||
~KickoffEntry()
|
|
||||||
{
|
|
||||||
auto conn = conn_.lock();
|
|
||||||
if (conn)
|
|
||||||
{
|
|
||||||
conn->forceClose();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
private:
|
void reset() {
|
||||||
std::weak_ptr<TcpConnection> conn_;
|
conn_.reset();
|
||||||
};
|
}
|
||||||
|
|
||||||
TcpConnectionImpl(EventLoop *loop,
|
~KickoffEntry() {
|
||||||
int socketfd,
|
auto conn = conn_.lock();
|
||||||
const InetAddress &localAddr,
|
if (conn) {
|
||||||
const InetAddress &peerAddr);
|
conn->forceClose();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
std::weak_ptr<TcpConnection> conn_;
|
||||||
|
};
|
||||||
|
|
||||||
|
TcpConnectionImpl(EventLoop *loop, int socketfd, const InetAddress &localAddr, const InetAddress &peerAddr);
|
||||||
#ifdef USE_OPENSSL
|
#ifdef USE_OPENSSL
|
||||||
TcpConnectionImpl(EventLoop *loop,
|
TcpConnectionImpl(EventLoop *loop, int socketfd, const InetAddress &localAddr, const InetAddress &peerAddr, const std::shared_ptr<SSLContext> &ctxPtr, bool isServer = true, bool validateCert = true, const std::string &hostname = "");
|
||||||
int socketfd,
|
|
||||||
const InetAddress &localAddr,
|
|
||||||
const InetAddress &peerAddr,
|
|
||||||
const std::shared_ptr<SSLContext> &ctxPtr,
|
|
||||||
bool isServer = true,
|
|
||||||
bool validateCert = true,
|
|
||||||
const std::string &hostname = "");
|
|
||||||
#endif
|
#endif
|
||||||
virtual ~TcpConnectionImpl();
|
|
||||||
virtual void send(const char *msg, size_t len) override;
|
|
||||||
virtual void send(const void *msg, size_t len) override;
|
|
||||||
virtual void send(const std::string &msg) override;
|
|
||||||
virtual void send(std::string &&msg) override;
|
|
||||||
virtual void send(const MsgBuffer &buffer) override;
|
|
||||||
virtual void send(MsgBuffer &&buffer) override;
|
|
||||||
virtual void send(const std::shared_ptr<std::string> &msgPtr) override;
|
|
||||||
virtual void send(const std::shared_ptr<MsgBuffer> &msgPtr) override;
|
|
||||||
virtual void sendFile(const char *fileName,
|
|
||||||
size_t offset = 0,
|
|
||||||
size_t length = 0) override;
|
|
||||||
|
|
||||||
virtual const InetAddress &localAddr() const override
|
virtual ~TcpConnectionImpl();
|
||||||
{
|
|
||||||
return localAddr_;
|
|
||||||
}
|
|
||||||
virtual const InetAddress &peerAddr() const override
|
|
||||||
{
|
|
||||||
return peerAddr_;
|
|
||||||
}
|
|
||||||
|
|
||||||
virtual bool connected() const override
|
virtual void send(const char *msg, size_t len) override;
|
||||||
{
|
virtual void send(const void *msg, size_t len) override;
|
||||||
return status_ == ConnStatus::Connected;
|
virtual void send(const std::string &msg) override;
|
||||||
}
|
virtual void send(std::string &&msg) override;
|
||||||
virtual bool disconnected() const override
|
virtual void send(const MsgBuffer &buffer) override;
|
||||||
{
|
virtual void send(MsgBuffer &&buffer) override;
|
||||||
return status_ == ConnStatus::Disconnected;
|
virtual void send(const std::shared_ptr<std::string> &msgPtr) override;
|
||||||
}
|
virtual void send(const std::shared_ptr<MsgBuffer> &msgPtr) override;
|
||||||
|
virtual void sendFile(const char *fileName, size_t offset = 0, size_t length = 0) override;
|
||||||
|
|
||||||
// virtual MsgBuffer* getSendBuffer() override{ return &writeBuffer_;}
|
virtual const InetAddress &localAddr() const override {
|
||||||
virtual MsgBuffer *getRecvBuffer() override
|
return localAddr_;
|
||||||
{
|
}
|
||||||
return &readBuffer_;
|
|
||||||
}
|
|
||||||
// set callbacks
|
|
||||||
virtual void setHighWaterMarkCallback(const HighWaterMarkCallback &cb,
|
|
||||||
size_t markLen) override
|
|
||||||
{
|
|
||||||
highWaterMarkCallback_ = cb;
|
|
||||||
highWaterMarkLen_ = markLen;
|
|
||||||
}
|
|
||||||
|
|
||||||
virtual void keepAlive() override
|
virtual const InetAddress &peerAddr() const override {
|
||||||
{
|
return peerAddr_;
|
||||||
idleTimeout_ = 0;
|
}
|
||||||
auto entry = kickoffEntry_.lock();
|
|
||||||
if (entry)
|
|
||||||
{
|
|
||||||
entry->reset();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
virtual bool isKeepAlive() override
|
|
||||||
{
|
|
||||||
return idleTimeout_ == 0;
|
|
||||||
}
|
|
||||||
virtual void setTcpNoDelay(bool on) override;
|
|
||||||
virtual void shutdown() override;
|
|
||||||
virtual void forceClose() override;
|
|
||||||
virtual EventLoop *getLoop() override
|
|
||||||
{
|
|
||||||
return loop_;
|
|
||||||
}
|
|
||||||
|
|
||||||
virtual size_t bytesSent() const override
|
virtual bool connected() const override {
|
||||||
{
|
return status_ == ConnStatus::Connected;
|
||||||
return bytesSent_;
|
}
|
||||||
}
|
|
||||||
virtual size_t bytesReceived() const override
|
|
||||||
{
|
|
||||||
return bytesReceived_;
|
|
||||||
}
|
|
||||||
virtual void startClientEncryption(
|
|
||||||
std::function<void()> callback,
|
|
||||||
bool useOldTLS = false,
|
|
||||||
bool validateCert = true,
|
|
||||||
std::string hostname = "",
|
|
||||||
const std::vector<std::pair<std::string, std::string>> &sslConfCmds =
|
|
||||||
{}) override;
|
|
||||||
virtual void startServerEncryption(const std::shared_ptr<SSLContext> &ctx,
|
|
||||||
std::function<void()> callback) override;
|
|
||||||
virtual bool isSSLConnection() const override
|
|
||||||
{
|
|
||||||
return isEncrypted_;
|
|
||||||
}
|
|
||||||
|
|
||||||
private:
|
virtual bool disconnected() const override {
|
||||||
/// Internal use only.
|
return status_ == ConnStatus::Disconnected;
|
||||||
|
}
|
||||||
|
|
||||||
std::weak_ptr<KickoffEntry> kickoffEntry_;
|
// virtual MsgBuffer* getSendBuffer() override{ return &writeBuffer_;}
|
||||||
std::weak_ptr<TimingWheel> timingWheelWeakPtr_;
|
virtual MsgBuffer *getRecvBuffer() override {
|
||||||
size_t idleTimeout_{0};
|
return &readBuffer_;
|
||||||
Date lastTimingWheelUpdateTime_;
|
}
|
||||||
|
|
||||||
void enableKickingOff(size_t timeout,
|
// set callbacks
|
||||||
const std::shared_ptr<TimingWheel> &timingWheel)
|
virtual void setHighWaterMarkCallback(const HighWaterMarkCallback &cb,
|
||||||
{
|
size_t markLen) override {
|
||||||
assert(timingWheel);
|
highWaterMarkCallback_ = cb;
|
||||||
assert(timingWheel->getLoop() == loop_);
|
highWaterMarkLen_ = markLen;
|
||||||
assert(timeout > 0);
|
}
|
||||||
auto entry = std::make_shared<KickoffEntry>(shared_from_this());
|
|
||||||
kickoffEntry_ = entry;
|
virtual void keepAlive() override {
|
||||||
timingWheelWeakPtr_ = timingWheel;
|
idleTimeout_ = 0;
|
||||||
idleTimeout_ = timeout;
|
auto entry = kickoffEntry_.lock();
|
||||||
timingWheel->insertEntry(timeout, entry);
|
if (entry) {
|
||||||
}
|
entry->reset();
|
||||||
void extendLife();
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
virtual bool isKeepAlive() override {
|
||||||
|
return idleTimeout_ == 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
virtual void setTcpNoDelay(bool on) override;
|
||||||
|
virtual void shutdown() override;
|
||||||
|
virtual void forceClose() override;
|
||||||
|
|
||||||
|
virtual EventLoop *getLoop() override {
|
||||||
|
return loop_;
|
||||||
|
}
|
||||||
|
|
||||||
|
virtual size_t bytesSent() const override {
|
||||||
|
return bytesSent_;
|
||||||
|
}
|
||||||
|
|
||||||
|
virtual size_t bytesReceived() const override {
|
||||||
|
return bytesReceived_;
|
||||||
|
}
|
||||||
|
|
||||||
|
virtual void startClientEncryption(std::function<void()> callback, bool useOldTLS = false, bool validateCert = true, std::string hostname = "", const std::vector<std::pair<std::string, std::string> > &sslConfCmds = {}) override;
|
||||||
|
virtual void startServerEncryption(const std::shared_ptr<SSLContext> &ctx, std::function<void()> callback) override;
|
||||||
|
virtual bool isSSLConnection() const override {
|
||||||
|
return isEncrypted_;
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
/// Internal use only.
|
||||||
|
|
||||||
|
std::weak_ptr<KickoffEntry> kickoffEntry_;
|
||||||
|
std::weak_ptr<TimingWheel> timingWheelWeakPtr_;
|
||||||
|
size_t idleTimeout_{ 0 };
|
||||||
|
Date lastTimingWheelUpdateTime_;
|
||||||
|
|
||||||
|
void enableKickingOff(size_t timeout,
|
||||||
|
const std::shared_ptr<TimingWheel> &timingWheel) {
|
||||||
|
assert(timingWheel);
|
||||||
|
assert(timingWheel->getLoop() == loop_);
|
||||||
|
assert(timeout > 0);
|
||||||
|
auto entry = std::make_shared<KickoffEntry>(shared_from_this());
|
||||||
|
kickoffEntry_ = entry;
|
||||||
|
timingWheelWeakPtr_ = timingWheel;
|
||||||
|
idleTimeout_ = timeout;
|
||||||
|
timingWheel->insertEntry(timeout, entry);
|
||||||
|
}
|
||||||
|
|
||||||
|
void extendLife();
|
||||||
#ifndef _WIN32
|
#ifndef _WIN32
|
||||||
void sendFile(int sfd, size_t offset = 0, size_t length = 0);
|
void sendFile(int sfd, size_t offset = 0, size_t length = 0);
|
||||||
#else
|
#else
|
||||||
void sendFile(FILE *fp, size_t offset = 0, size_t length = 0);
|
void sendFile(FILE *fp, size_t offset = 0, size_t length = 0);
|
||||||
#endif
|
#endif
|
||||||
void setRecvMsgCallback(const RecvMessageCallback &cb)
|
void setRecvMsgCallback(const RecvMessageCallback &cb) {
|
||||||
{
|
recvMsgCallback_ = cb;
|
||||||
recvMsgCallback_ = cb;
|
}
|
||||||
}
|
void setConnectionCallback(const ConnectionCallback &cb) {
|
||||||
void setConnectionCallback(const ConnectionCallback &cb)
|
connectionCallback_ = cb;
|
||||||
{
|
}
|
||||||
connectionCallback_ = cb;
|
void setWriteCompleteCallback(const WriteCompleteCallback &cb) {
|
||||||
}
|
writeCompleteCallback_ = cb;
|
||||||
void setWriteCompleteCallback(const WriteCompleteCallback &cb)
|
}
|
||||||
{
|
void setCloseCallback(const CloseCallback &cb) {
|
||||||
writeCompleteCallback_ = cb;
|
closeCallback_ = cb;
|
||||||
}
|
}
|
||||||
void setCloseCallback(const CloseCallback &cb)
|
void setSSLErrorCallback(const SSLErrorCallback &cb) {
|
||||||
{
|
sslErrorCallback_ = cb;
|
||||||
closeCallback_ = cb;
|
}
|
||||||
}
|
|
||||||
void setSSLErrorCallback(const SSLErrorCallback &cb)
|
|
||||||
{
|
|
||||||
sslErrorCallback_ = cb;
|
|
||||||
}
|
|
||||||
|
|
||||||
void connectDestroyed();
|
void connectDestroyed();
|
||||||
virtual void connectEstablished();
|
virtual void connectEstablished();
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
struct BufferNode
|
struct BufferNode {
|
||||||
{
|
|
||||||
#ifndef _WIN32
|
#ifndef _WIN32
|
||||||
int sendFd_{-1};
|
int sendFd_{ -1 };
|
||||||
off_t offset_;
|
off_t offset_;
|
||||||
#else
|
#else
|
||||||
FILE *sendFp_{nullptr};
|
FILE *sendFp_{ nullptr };
|
||||||
long long offset_;
|
long long offset_;
|
||||||
#endif
|
#endif
|
||||||
ssize_t fileBytesToSend_;
|
ssize_t fileBytesToSend_;
|
||||||
std::shared_ptr<MsgBuffer> msgBuffer_;
|
std::shared_ptr<MsgBuffer> msgBuffer_;
|
||||||
~BufferNode()
|
|
||||||
{
|
~BufferNode() {
|
||||||
#ifndef _WIN32
|
#ifndef _WIN32
|
||||||
if (sendFd_ >= 0)
|
if (sendFd_ >= 0)
|
||||||
close(sendFd_);
|
close(sendFd_);
|
||||||
#else
|
#else
|
||||||
if (sendFp_)
|
if (sendFp_)
|
||||||
fclose(sendFp_);
|
fclose(sendFp_);
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
using BufferNodePtr = std::shared_ptr<BufferNode>;
|
|
||||||
enum class ConnStatus
|
|
||||||
{
|
|
||||||
Disconnected,
|
|
||||||
Connecting,
|
|
||||||
Connected,
|
|
||||||
Disconnecting
|
|
||||||
};
|
|
||||||
bool isEncrypted_{false};
|
|
||||||
EventLoop *loop_;
|
|
||||||
std::unique_ptr<Channel> ioChannelPtr_;
|
|
||||||
std::unique_ptr<Socket> socketPtr_;
|
|
||||||
MsgBuffer readBuffer_;
|
|
||||||
std::list<BufferNodePtr> writeBufferList_;
|
|
||||||
void readCallback();
|
|
||||||
void writeCallback();
|
|
||||||
InetAddress localAddr_, peerAddr_;
|
|
||||||
ConnStatus status_{ConnStatus::Connecting};
|
|
||||||
// callbacks
|
|
||||||
RecvMessageCallback recvMsgCallback_;
|
|
||||||
ConnectionCallback connectionCallback_;
|
|
||||||
CloseCallback closeCallback_;
|
|
||||||
WriteCompleteCallback writeCompleteCallback_;
|
|
||||||
HighWaterMarkCallback highWaterMarkCallback_;
|
|
||||||
SSLErrorCallback sslErrorCallback_;
|
|
||||||
void handleClose();
|
|
||||||
void handleError();
|
|
||||||
// virtual void sendInLoop(const std::string &msg);
|
|
||||||
|
|
||||||
void sendFileInLoop(const BufferNodePtr &file);
|
using BufferNodePtr = std::shared_ptr<BufferNode>;
|
||||||
|
enum class ConnStatus {
|
||||||
|
Disconnected,
|
||||||
|
Connecting,
|
||||||
|
Connected,
|
||||||
|
Disconnecting
|
||||||
|
};
|
||||||
|
|
||||||
|
bool isEncrypted_{ false };
|
||||||
|
EventLoop *loop_;
|
||||||
|
std::unique_ptr<Channel> ioChannelPtr_;
|
||||||
|
std::unique_ptr<Socket> socketPtr_;
|
||||||
|
MsgBuffer readBuffer_;
|
||||||
|
std::list<BufferNodePtr> writeBufferList_;
|
||||||
|
|
||||||
|
void readCallback();
|
||||||
|
void writeCallback();
|
||||||
|
|
||||||
|
InetAddress localAddr_, peerAddr_;
|
||||||
|
ConnStatus status_{ ConnStatus::Connecting };
|
||||||
|
// callbacks
|
||||||
|
RecvMessageCallback recvMsgCallback_;
|
||||||
|
ConnectionCallback connectionCallback_;
|
||||||
|
CloseCallback closeCallback_;
|
||||||
|
WriteCompleteCallback writeCompleteCallback_;
|
||||||
|
HighWaterMarkCallback highWaterMarkCallback_;
|
||||||
|
SSLErrorCallback sslErrorCallback_;
|
||||||
|
|
||||||
|
void handleClose();
|
||||||
|
void handleError();
|
||||||
|
// virtual void sendInLoop(const std::string &msg);
|
||||||
|
|
||||||
|
void sendFileInLoop(const BufferNodePtr &file);
|
||||||
#ifndef _WIN32
|
#ifndef _WIN32
|
||||||
void sendInLoop(const void *buffer, size_t length);
|
void sendInLoop(const void *buffer, size_t length);
|
||||||
ssize_t writeInLoop(const void *buffer, size_t length);
|
ssize_t writeInLoop(const void *buffer, size_t length);
|
||||||
#else
|
#else
|
||||||
void sendInLoop(const char *buffer, size_t length);
|
void sendInLoop(const char *buffer, size_t length);
|
||||||
ssize_t writeInLoop(const char *buffer, size_t length);
|
ssize_t writeInLoop(const char *buffer, size_t length);
|
||||||
#endif
|
#endif
|
||||||
size_t highWaterMarkLen_;
|
size_t highWaterMarkLen_;
|
||||||
std::string name_;
|
std::string name_;
|
||||||
|
|
||||||
uint64_t sendNum_{0};
|
uint64_t sendNum_{ 0 };
|
||||||
std::mutex sendNumMutex_;
|
std::mutex sendNumMutex_;
|
||||||
|
|
||||||
size_t bytesSent_{0};
|
size_t bytesSent_{ 0 };
|
||||||
size_t bytesReceived_{0};
|
size_t bytesReceived_{ 0 };
|
||||||
|
|
||||||
std::unique_ptr<std::vector<char>> fileBufferPtr_;
|
std::unique_ptr<std::vector<char> > fileBufferPtr_;
|
||||||
|
|
||||||
#ifdef USE_OPENSSL
|
#ifdef USE_OPENSSL
|
||||||
private:
|
private:
|
||||||
void doHandshaking();
|
void doHandshaking();
|
||||||
bool validatePeerCertificate();
|
bool validatePeerCertificate();
|
||||||
struct SSLEncryption
|
|
||||||
{
|
struct SSLEncryption {
|
||||||
SSLStatus statusOfSSL_ = SSLStatus::Handshaking;
|
SSLStatus statusOfSSL_ = SSLStatus::Handshaking;
|
||||||
// OpenSSL
|
// OpenSSL
|
||||||
std::shared_ptr<SSLContext> sslCtxPtr_;
|
std::shared_ptr<SSLContext> sslCtxPtr_;
|
||||||
std::unique_ptr<SSLConn> sslPtr_;
|
std::unique_ptr<SSLConn> sslPtr_;
|
||||||
std::unique_ptr<std::array<char, 8192>> sendBufferPtr_;
|
std::unique_ptr<std::array<char, 8192> > sendBufferPtr_;
|
||||||
bool isServer_{false};
|
bool isServer_{ false };
|
||||||
bool isUpgrade_{false};
|
bool isUpgrade_{ false };
|
||||||
std::function<void()> upgradeCallback_;
|
std::function<void()> upgradeCallback_;
|
||||||
std::string hostname_;
|
std::string hostname_;
|
||||||
};
|
};
|
||||||
std::unique_ptr<SSLEncryption> sslEncryptionPtr_;
|
|
||||||
void startClientEncryptionInLoop(
|
std::unique_ptr<SSLEncryption> sslEncryptionPtr_;
|
||||||
std::function<void()> &&callback,
|
|
||||||
bool useOldTLS,
|
void startClientEncryptionInLoop(std::function<void()> &&callback, bool useOldTLS, bool validateCert, const std::string &hostname, const std::vector<std::pair<std::string, std::string> > &sslConfCmds);
|
||||||
bool validateCert,
|
void startServerEncryptionInLoop(const std::shared_ptr<SSLContext> &ctx, std::function<void()> &&callback);
|
||||||
const std::string &hostname,
|
|
||||||
const std::vector<std::pair<std::string, std::string>> &sslConfCmds);
|
|
||||||
void startServerEncryptionInLoop(const std::shared_ptr<SSLContext> &ctx,
|
|
||||||
std::function<void()> &&callback);
|
|
||||||
#endif
|
#endif
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -35,6 +35,7 @@
|
|||||||
Connector::Connector(EventLoop *loop, const InetAddress &addr, bool retry) :
|
Connector::Connector(EventLoop *loop, const InetAddress &addr, bool retry) :
|
||||||
loop_(loop), serverAddr_(addr), retry_(retry) {
|
loop_(loop), serverAddr_(addr), retry_(retry) {
|
||||||
}
|
}
|
||||||
|
|
||||||
Connector::Connector(EventLoop *loop, InetAddress &&addr, bool retry) :
|
Connector::Connector(EventLoop *loop, InetAddress &&addr, bool retry) :
|
||||||
loop_(loop), serverAddr_(std::move(addr)), retry_(retry) {
|
loop_(loop), serverAddr_(std::move(addr)), retry_(retry) {
|
||||||
}
|
}
|
||||||
@ -43,8 +44,10 @@ void Connector::start() {
|
|||||||
connect_ = true;
|
connect_ = true;
|
||||||
loop_->runInLoop([this]() { startInLoop(); });
|
loop_->runInLoop([this]() { startInLoop(); });
|
||||||
}
|
}
|
||||||
|
|
||||||
void Connector::restart() {
|
void Connector::restart() {
|
||||||
}
|
}
|
||||||
|
|
||||||
void Connector::stop() {
|
void Connector::stop() {
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -57,6 +60,7 @@ void Connector::startInLoop() {
|
|||||||
LOG_DEBUG << "do not connect";
|
LOG_DEBUG << "do not connect";
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void Connector::connect() {
|
void Connector::connect() {
|
||||||
int sockfd = Socket::createNonblockingSocketOrDie(serverAddr_.family());
|
int sockfd = Socket::createNonblockingSocketOrDie(serverAddr_.family());
|
||||||
errno = 0;
|
errno = 0;
|
||||||
|
@ -46,23 +46,30 @@ protected:
|
|||||||
public:
|
public:
|
||||||
using NewConnectionCallback = std::function<void(int sockfd)>;
|
using NewConnectionCallback = std::function<void(int sockfd)>;
|
||||||
using ConnectionErrorCallback = std::function<void()>;
|
using ConnectionErrorCallback = std::function<void()>;
|
||||||
|
|
||||||
Connector(EventLoop *loop, const InetAddress &addr, bool retry = true);
|
Connector(EventLoop *loop, const InetAddress &addr, bool retry = true);
|
||||||
Connector(EventLoop *loop, InetAddress &&addr, bool retry = true);
|
Connector(EventLoop *loop, InetAddress &&addr, bool retry = true);
|
||||||
|
|
||||||
void setNewConnectionCallback(const NewConnectionCallback &cb) {
|
void setNewConnectionCallback(const NewConnectionCallback &cb) {
|
||||||
newConnectionCallback_ = cb;
|
newConnectionCallback_ = cb;
|
||||||
}
|
}
|
||||||
|
|
||||||
void setNewConnectionCallback(NewConnectionCallback &&cb) {
|
void setNewConnectionCallback(NewConnectionCallback &&cb) {
|
||||||
newConnectionCallback_ = std::move(cb);
|
newConnectionCallback_ = std::move(cb);
|
||||||
}
|
}
|
||||||
|
|
||||||
void setErrorCallback(const ConnectionErrorCallback &cb) {
|
void setErrorCallback(const ConnectionErrorCallback &cb) {
|
||||||
errorCallback_ = cb;
|
errorCallback_ = cb;
|
||||||
}
|
}
|
||||||
|
|
||||||
void setErrorCallback(ConnectionErrorCallback &&cb) {
|
void setErrorCallback(ConnectionErrorCallback &&cb) {
|
||||||
errorCallback_ = std::move(cb);
|
errorCallback_ = std::move(cb);
|
||||||
}
|
}
|
||||||
|
|
||||||
const InetAddress &serverAddress() const {
|
const InetAddress &serverAddress() const {
|
||||||
return serverAddr_;
|
return serverAddr_;
|
||||||
}
|
}
|
||||||
|
|
||||||
void start();
|
void start();
|
||||||
void restart();
|
void restart();
|
||||||
void stop();
|
void stop();
|
||||||
@ -70,11 +77,13 @@ public:
|
|||||||
private:
|
private:
|
||||||
NewConnectionCallback newConnectionCallback_;
|
NewConnectionCallback newConnectionCallback_;
|
||||||
ConnectionErrorCallback errorCallback_;
|
ConnectionErrorCallback errorCallback_;
|
||||||
|
|
||||||
enum class Status {
|
enum class Status {
|
||||||
Disconnected,
|
Disconnected,
|
||||||
Connecting,
|
Connecting,
|
||||||
Connected
|
Connected
|
||||||
};
|
};
|
||||||
|
|
||||||
static constexpr int kMaxRetryDelayMs = 30 * 1000;
|
static constexpr int kMaxRetryDelayMs = 30 * 1000;
|
||||||
static constexpr int kInitRetryDelayMs = 500;
|
static constexpr int kInitRetryDelayMs = 500;
|
||||||
std::shared_ptr<Channel> channelPtr_;
|
std::shared_ptr<Channel> channelPtr_;
|
||||||
|
@ -98,8 +98,7 @@ InetAddress::InetAddress(uint16_t port, bool loopbackOnly, bool ipv6) :
|
|||||||
isUnspecified_ = false;
|
isUnspecified_ = false;
|
||||||
}
|
}
|
||||||
|
|
||||||
InetAddress::InetAddress(const std::string &ip, uint16_t port, bool ipv6) :
|
InetAddress::InetAddress(const std::string &ip, uint16_t port, bool ipv6) : isIpV6_(ipv6) {
|
||||||
isIpV6_(ipv6) {
|
|
||||||
if (ipv6) {
|
if (ipv6) {
|
||||||
memset(&addr6_, 0, sizeof(addr6_));
|
memset(&addr6_, 0, sizeof(addr6_));
|
||||||
addr6_.sin6_family = AF_INET6;
|
addr6_.sin6_family = AF_INET6;
|
||||||
|
@ -38,195 +38,76 @@ using sa_family_t = unsigned short;
|
|||||||
using in_addr_t = uint32_t;
|
using in_addr_t = uint32_t;
|
||||||
using uint16_t = unsigned short;
|
using uint16_t = unsigned short;
|
||||||
#else
|
#else
|
||||||
#include <netinet/in.h>
|
|
||||||
#include <arpa/inet.h>
|
#include <arpa/inet.h>
|
||||||
|
#include <netinet/in.h>
|
||||||
#include <sys/socket.h>
|
#include <sys/socket.h>
|
||||||
#endif
|
#endif
|
||||||
|
#include <mutex>
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <unordered_map>
|
#include <unordered_map>
|
||||||
#include <mutex>
|
|
||||||
|
|
||||||
/**
|
class InetAddress {
|
||||||
* @brief Wrapper of sockaddr_in. This is an POD interface class.
|
public:
|
||||||
*
|
InetAddress(uint16_t port = 0, bool loopbackOnly = false, bool ipv6 = false);
|
||||||
*/
|
|
||||||
class InetAddress
|
|
||||||
{
|
|
||||||
public:
|
|
||||||
/**
|
|
||||||
* @brief Constructs an endpoint with given port number. Mostly used in
|
|
||||||
* TcpServer listening.
|
|
||||||
*
|
|
||||||
* @param port
|
|
||||||
* @param loopbackOnly
|
|
||||||
* @param ipv6
|
|
||||||
*/
|
|
||||||
InetAddress(uint16_t port = 0,
|
|
||||||
bool loopbackOnly = false,
|
|
||||||
bool ipv6 = false);
|
|
||||||
|
|
||||||
/**
|
InetAddress(const std::string &ip, uint16_t port, bool ipv6 = false);
|
||||||
* @brief Constructs an endpoint with given ip and port.
|
|
||||||
*
|
|
||||||
* @param ip A IPv4 or IPv6 address.
|
|
||||||
* @param port
|
|
||||||
* @param ipv6
|
|
||||||
*/
|
|
||||||
InetAddress(const std::string &ip, uint16_t port, bool ipv6 = false);
|
|
||||||
|
|
||||||
/**
|
explicit InetAddress(const struct sockaddr_in &addr) :
|
||||||
* @brief Constructs an endpoint with given struct `sockaddr_in`. Mostly
|
addr_(addr), isUnspecified_(false) {
|
||||||
* used when accepting new connections
|
}
|
||||||
*
|
|
||||||
* @param addr
|
|
||||||
*/
|
|
||||||
explicit InetAddress(const struct sockaddr_in &addr)
|
|
||||||
: addr_(addr), isUnspecified_(false)
|
|
||||||
{
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
explicit InetAddress(const struct sockaddr_in6 &addr) :
|
||||||
* @brief Constructs an IPv6 endpoint with given struct `sockaddr_in6`.
|
addr6_(addr), isIpV6_(true), isUnspecified_(false) {
|
||||||
* Mostly used when accepting new connections
|
}
|
||||||
*
|
|
||||||
* @param addr
|
|
||||||
*/
|
|
||||||
explicit InetAddress(const struct sockaddr_in6 &addr)
|
|
||||||
: addr6_(addr), isIpV6_(true), isUnspecified_(false)
|
|
||||||
{
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
sa_family_t family() const {
|
||||||
* @brief Return the sin_family of the endpoint.
|
return addr_.sin_family;
|
||||||
*
|
}
|
||||||
* @return sa_family_t
|
|
||||||
*/
|
|
||||||
sa_family_t family() const
|
|
||||||
{
|
|
||||||
return addr_.sin_family;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
std::string toIp() const;
|
||||||
* @brief Return the IP string of the endpoint.
|
std::string toIpPort() const;
|
||||||
*
|
uint16_t toPort() const;
|
||||||
* @return std::string
|
|
||||||
*/
|
|
||||||
std::string toIp() const;
|
|
||||||
|
|
||||||
/**
|
bool isIpV6() const {
|
||||||
* @brief Return the IP and port string of the endpoint.
|
return isIpV6_;
|
||||||
*
|
}
|
||||||
* @return std::string
|
|
||||||
*/
|
|
||||||
std::string toIpPort() const;
|
|
||||||
|
|
||||||
/**
|
bool isIntranetIp() const;
|
||||||
* @brief Return the port number of the endpoint.
|
bool isLoopbackIp() const;
|
||||||
*
|
|
||||||
* @return uint16_t
|
|
||||||
*/
|
|
||||||
uint16_t toPort() const;
|
|
||||||
|
|
||||||
/**
|
const struct sockaddr *getSockAddr() const {
|
||||||
* @brief Check if the endpoint is IPv4 or IPv6.
|
return static_cast<const struct sockaddr *>((void *)(&addr6_));
|
||||||
*
|
}
|
||||||
* @return true
|
|
||||||
* @return false
|
|
||||||
*/
|
|
||||||
bool isIpV6() const
|
|
||||||
{
|
|
||||||
return isIpV6_;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
void setSockAddrInet6(const struct sockaddr_in6 &addr6) {
|
||||||
* @brief Return true if the endpoint is an intranet endpoint.
|
addr6_ = addr6;
|
||||||
*
|
isIpV6_ = (addr6_.sin6_family == AF_INET6);
|
||||||
* @return true
|
isUnspecified_ = false;
|
||||||
* @return false
|
}
|
||||||
*/
|
|
||||||
bool isIntranetIp() const;
|
|
||||||
|
|
||||||
/**
|
uint32_t ipNetEndian() const;
|
||||||
* @brief Return true if the endpoint is a loopback endpoint.
|
const uint32_t *ip6NetEndian() const;
|
||||||
*
|
|
||||||
* @return true
|
|
||||||
* @return false
|
|
||||||
*/
|
|
||||||
bool isLoopbackIp() const;
|
|
||||||
|
|
||||||
/**
|
uint16_t portNetEndian() const {
|
||||||
* @brief Get the pointer to the sockaddr struct.
|
return addr_.sin_port;
|
||||||
*
|
}
|
||||||
* @return const struct sockaddr*
|
|
||||||
*/
|
|
||||||
const struct sockaddr *getSockAddr() const
|
|
||||||
{
|
|
||||||
return static_cast<const struct sockaddr *>((void *)(&addr6_));
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
void setPortNetEndian(uint16_t port) {
|
||||||
* @brief Set the sockaddr_in6 struct in the endpoint.
|
addr_.sin_port = port;
|
||||||
*
|
}
|
||||||
* @param addr6
|
|
||||||
*/
|
|
||||||
void setSockAddrInet6(const struct sockaddr_in6 &addr6)
|
|
||||||
{
|
|
||||||
addr6_ = addr6;
|
|
||||||
isIpV6_ = (addr6_.sin6_family == AF_INET6);
|
|
||||||
isUnspecified_ = false;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
inline bool isUnspecified() const {
|
||||||
* @brief Return the integer value of the IP(v4) in net endian byte order.
|
return isUnspecified_;
|
||||||
*
|
}
|
||||||
* @return uint32_t
|
|
||||||
*/
|
|
||||||
uint32_t ipNetEndian() const;
|
|
||||||
|
|
||||||
/**
|
private:
|
||||||
* @brief Return the pointer to the integer value of the IP(v6) in net
|
union {
|
||||||
* endian byte order.
|
struct sockaddr_in addr_;
|
||||||
*
|
struct sockaddr_in6 addr6_;
|
||||||
* @return const uint32_t*
|
};
|
||||||
*/
|
|
||||||
const uint32_t *ip6NetEndian() const;
|
|
||||||
|
|
||||||
/**
|
bool isIpV6_{ false };
|
||||||
* @brief Return the port number in net endian byte order.
|
bool isUnspecified_{ true };
|
||||||
*
|
|
||||||
* @return uint16_t
|
|
||||||
*/
|
|
||||||
uint16_t portNetEndian() const
|
|
||||||
{
|
|
||||||
return addr_.sin_port;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief Set the port number in net endian byte order.
|
|
||||||
*
|
|
||||||
* @param port
|
|
||||||
*/
|
|
||||||
void setPortNetEndian(uint16_t port)
|
|
||||||
{
|
|
||||||
addr_.sin_port = port;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief Return true if the address is not initalized.
|
|
||||||
*/
|
|
||||||
inline bool isUnspecified() const
|
|
||||||
{
|
|
||||||
return isUnspecified_;
|
|
||||||
}
|
|
||||||
|
|
||||||
private:
|
|
||||||
union
|
|
||||||
{
|
|
||||||
struct sockaddr_in addr_;
|
|
||||||
struct sockaddr_in6 addr6_;
|
|
||||||
};
|
|
||||||
bool isIpV6_{false};
|
|
||||||
bool isUnspecified_{true};
|
|
||||||
};
|
};
|
||||||
|
|
||||||
#endif // MUDUO_NET_INETADDRESS_H
|
#endif // MUDUO_NET_INETADDRESS_H
|
||||||
|
@ -34,44 +34,20 @@
|
|||||||
#include "core/loops/event_loop.h"
|
#include "core/loops/event_loop.h"
|
||||||
#include "core/net/inet_address.h"
|
#include "core/net/inet_address.h"
|
||||||
|
|
||||||
/**
|
//make it a reference
|
||||||
* @brief This class represents an asynchronous DNS resolver.
|
|
||||||
* @note Although the c-ares library is not essential, it is recommended to
|
|
||||||
* install it for higher performance
|
|
||||||
*/
|
|
||||||
class Resolver
|
class Resolver
|
||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
using Callback = std::function<void(const InetAddress&)>;
|
using Callback = std::function<void(const InetAddress&)>;
|
||||||
|
|
||||||
/**
|
static std::shared_ptr<Resolver> newResolver(EventLoop* loop = nullptr, size_t timeout = 60);
|
||||||
* @brief Create a new DNS resolver.
|
|
||||||
*
|
|
||||||
* @param loop The event loop in which the DNS resolver runs.
|
|
||||||
* @param timeout The timeout in seconds for DNS.
|
|
||||||
* @return std::shared_ptr<Resolver>
|
|
||||||
*/
|
|
||||||
static std::shared_ptr<Resolver> newResolver(EventLoop* loop = nullptr,
|
|
||||||
size_t timeout = 60);
|
|
||||||
|
|
||||||
/**
|
virtual void resolve(const std::string& hostname, const Callback& callback) = 0;
|
||||||
* @brief Resolve an address asynchronously.
|
|
||||||
*
|
|
||||||
* @param hostname
|
|
||||||
* @param callback
|
|
||||||
*/
|
|
||||||
virtual void resolve(const std::string& hostname,
|
|
||||||
const Callback& callback) = 0;
|
|
||||||
|
|
||||||
virtual ~Resolver()
|
virtual ~Resolver()
|
||||||
{
|
{
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief Check whether the c-ares library is used.
|
|
||||||
*
|
|
||||||
* @return true
|
|
||||||
* @return false
|
|
||||||
*/
|
|
||||||
static bool isCAresUsed();
|
static bool isCAresUsed();
|
||||||
};
|
};
|
||||||
|
@ -29,8 +29,8 @@
|
|||||||
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||||
|
|
||||||
#include "ares_resolver.h"
|
#include "ares_resolver.h"
|
||||||
#include <ares.h>
|
|
||||||
#include "core/loops/channel.h"
|
#include "core/loops/channel.h"
|
||||||
|
#include <ares.h>
|
||||||
#ifdef _WIN32
|
#ifdef _WIN32
|
||||||
#include <winsock2.h>
|
#include <winsock2.h>
|
||||||
#else
|
#else
|
||||||
@ -67,14 +67,14 @@ bool Resolver::isCAresUsed() {
|
|||||||
AresResolver::LibraryInitializer::LibraryInitializer() {
|
AresResolver::LibraryInitializer::LibraryInitializer() {
|
||||||
ares_library_init(ARES_LIB_INIT_ALL);
|
ares_library_init(ARES_LIB_INIT_ALL);
|
||||||
}
|
}
|
||||||
|
|
||||||
AresResolver::LibraryInitializer::~LibraryInitializer() {
|
AresResolver::LibraryInitializer::~LibraryInitializer() {
|
||||||
ares_library_cleanup();
|
ares_library_cleanup();
|
||||||
}
|
}
|
||||||
|
|
||||||
AresResolver::LibraryInitializer AresResolver::libraryInitializer_;
|
AresResolver::LibraryInitializer AresResolver::libraryInitializer_;
|
||||||
|
|
||||||
std::shared_ptr<Resolver> Resolver::newResolver(EventLoop *loop,
|
std::shared_ptr<Resolver> Resolver::newResolver(EventLoop *loop, size_t timeout) {
|
||||||
size_t timeout) {
|
|
||||||
return std::make_shared<AresResolver>(loop, timeout);
|
return std::make_shared<AresResolver>(loop, timeout);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -84,6 +84,7 @@ AresResolver::AresResolver(EventLoop *loop, size_t timeout) :
|
|||||||
loop_ = getLoop();
|
loop_ = getLoop();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void AresResolver::init() {
|
void AresResolver::init() {
|
||||||
if (!ctx_) {
|
if (!ctx_) {
|
||||||
struct ares_options options;
|
struct ares_options options;
|
||||||
@ -108,14 +109,16 @@ void AresResolver::init() {
|
|||||||
this);
|
this);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
AresResolver::~AresResolver() {
|
AresResolver::~AresResolver() {
|
||||||
if (ctx_)
|
if (ctx_)
|
||||||
ares_destroy(ctx_);
|
ares_destroy(ctx_);
|
||||||
}
|
}
|
||||||
|
|
||||||
void AresResolver::resolveInLoop(const std::string &hostname,
|
void AresResolver::resolveInLoop(const std::string &hostname, const Callback &cb) {
|
||||||
const Callback &cb) {
|
|
||||||
loop_->assertInLoopThread();
|
loop_->assertInLoopThread();
|
||||||
|
|
||||||
#ifdef _WIN32
|
#ifdef _WIN32
|
||||||
if (hostname == "localhost") {
|
if (hostname == "localhost") {
|
||||||
const static InetAddress localhost_{ "127.0.0.1", 0 };
|
const static InetAddress localhost_{ "127.0.0.1", 0 };
|
||||||
@ -123,20 +126,17 @@ void AresResolver::resolveInLoop(const std::string &hostname,
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
init();
|
init();
|
||||||
QueryData *queryData = new QueryData(this, cb, hostname);
|
QueryData *queryData = new QueryData(this, cb, hostname);
|
||||||
ares_gethostbyname(ctx_,
|
ares_gethostbyname(ctx_, hostname.c_str(), AF_INET, &AresResolver::ares_hostcallback_, queryData);
|
||||||
hostname.c_str(),
|
|
||||||
AF_INET,
|
|
||||||
&AresResolver::ares_hostcallback_,
|
|
||||||
queryData);
|
|
||||||
struct timeval tv;
|
struct timeval tv;
|
||||||
struct timeval *tvp = ares_timeout(ctx_, NULL, &tv);
|
struct timeval *tvp = ares_timeout(ctx_, NULL, &tv);
|
||||||
double timeout = getSeconds(tvp);
|
double timeout = getSeconds(tvp);
|
||||||
|
|
||||||
// LOG_DEBUG << "timeout " << timeout << " active " << timerActive_;
|
// LOG_DEBUG << "timeout " << timeout << " active " << timerActive_;
|
||||||
if (!timerActive_ && timeout >= 0.0) {
|
if (!timerActive_ && timeout >= 0.0) {
|
||||||
loop_->runAfter(timeout,
|
loop_->runAfter(timeout, std::bind(&AresResolver::onTimer, shared_from_this()));
|
||||||
std::bind(&AresResolver::onTimer, shared_from_this()));
|
|
||||||
timerActive_ = true;
|
timerActive_ = true;
|
||||||
}
|
}
|
||||||
return;
|
return;
|
||||||
@ -161,18 +161,18 @@ void AresResolver::onTimer() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void AresResolver::onQueryResult(int status,
|
void AresResolver::onQueryResult(int status, struct hostent *result, const std::string &hostname, const Callback &callback) {
|
||||||
struct hostent *result,
|
|
||||||
const std::string &hostname,
|
|
||||||
const Callback &callback) {
|
|
||||||
LOG_TRACE << "onQueryResult " << status;
|
LOG_TRACE << "onQueryResult " << status;
|
||||||
struct sockaddr_in addr;
|
struct sockaddr_in addr;
|
||||||
memset(&addr, 0, sizeof addr);
|
memset(&addr, 0, sizeof addr);
|
||||||
addr.sin_family = AF_INET;
|
addr.sin_family = AF_INET;
|
||||||
addr.sin_port = 0;
|
addr.sin_port = 0;
|
||||||
|
|
||||||
if (result) {
|
if (result) {
|
||||||
addr.sin_addr = *reinterpret_cast<in_addr *>(result->h_addr);
|
addr.sin_addr = *reinterpret_cast<in_addr *>(result->h_addr);
|
||||||
}
|
}
|
||||||
|
|
||||||
InetAddress inet(addr);
|
InetAddress inet(addr);
|
||||||
{
|
{
|
||||||
std::lock_guard<std::mutex> lock(globalMutex());
|
std::lock_guard<std::mutex> lock(globalMutex());
|
||||||
@ -180,6 +180,7 @@ void AresResolver::onQueryResult(int status,
|
|||||||
addrItem.first = addr.sin_addr;
|
addrItem.first = addr.sin_addr;
|
||||||
addrItem.second = Date::date();
|
addrItem.second = Date::date();
|
||||||
}
|
}
|
||||||
|
|
||||||
callback(inet);
|
callback(inet);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -198,6 +199,7 @@ void AresResolver::onSockStateChange(int sockfd, bool read, bool write) {
|
|||||||
loop_->assertInLoopThread();
|
loop_->assertInLoopThread();
|
||||||
ChannelList::iterator it = channels_.find(sockfd);
|
ChannelList::iterator it = channels_.find(sockfd);
|
||||||
assert(it != channels_.end());
|
assert(it != channels_.end());
|
||||||
|
|
||||||
if (read) {
|
if (read) {
|
||||||
// update
|
// update
|
||||||
// if (write) { } else { }
|
// if (write) { } else { }
|
||||||
@ -209,17 +211,12 @@ void AresResolver::onSockStateChange(int sockfd, bool read, bool write) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void AresResolver::ares_hostcallback_(void *data,
|
void AresResolver::ares_hostcallback_(void *data, int status, int timeouts, struct hostent *hostent) {
|
||||||
int status,
|
|
||||||
int timeouts,
|
|
||||||
struct hostent *hostent) {
|
|
||||||
(void)timeouts;
|
(void)timeouts;
|
||||||
QueryData *query = static_cast<QueryData *>(data);
|
QueryData *query = static_cast<QueryData *>(data);
|
||||||
|
|
||||||
query->owner_->onQueryResult(status,
|
query->owner_->onQueryResult(status, hostent, query->hostname_, query->callback_);
|
||||||
hostent,
|
|
||||||
query->hostname_,
|
|
||||||
query->callback_);
|
|
||||||
delete query;
|
delete query;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -242,6 +239,7 @@ void AresResolver::ares_sock_statecallback_(void *data,
|
|||||||
#endif
|
#endif
|
||||||
int read,
|
int read,
|
||||||
int write) {
|
int write) {
|
||||||
|
|
||||||
LOG_TRACE << "sockfd=" << sockfd << " read=" << read << " write=" << write;
|
LOG_TRACE << "sockfd=" << sockfd << " read=" << read << " write=" << write;
|
||||||
static_cast<AresResolver *>(data)->onSockStateChange(sockfd, read, write);
|
static_cast<AresResolver *>(data)->onSockStateChange(sockfd, read, write);
|
||||||
}
|
}
|
||||||
|
@ -36,14 +36,15 @@
|
|||||||
#include <map>
|
#include <map>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
|
|
||||||
|
// Resolver will be a ref
|
||||||
|
|
||||||
extern "C" {
|
extern "C" {
|
||||||
struct hostent;
|
struct hostent;
|
||||||
struct ares_channeldata;
|
struct ares_channeldata;
|
||||||
using ares_channel = struct ares_channeldata *;
|
using ares_channel = struct ares_channeldata *;
|
||||||
}
|
}
|
||||||
|
|
||||||
class AresResolver : public Resolver,
|
class AresResolver : public Resolver, public std::enable_shared_from_this<AresResolver> {
|
||||||
public std::enable_shared_from_this<AresResolver> {
|
|
||||||
protected:
|
protected:
|
||||||
AresResolver(const AresResolver &) = delete;
|
AresResolver(const AresResolver &) = delete;
|
||||||
AresResolver &operator=(const AresResolver &) = delete;
|
AresResolver &operator=(const AresResolver &) = delete;
|
||||||
@ -55,9 +56,9 @@ public:
|
|||||||
AresResolver(EventLoop *loop, size_t timeout);
|
AresResolver(EventLoop *loop, size_t timeout);
|
||||||
~AresResolver();
|
~AresResolver();
|
||||||
|
|
||||||
virtual void resolve(const std::string &hostname,
|
virtual void resolve(const std::string &hostname, const Callback &cb) override {
|
||||||
const Callback &cb) override {
|
|
||||||
bool cached = false;
|
bool cached = false;
|
||||||
|
|
||||||
InetAddress inet;
|
InetAddress inet;
|
||||||
{
|
{
|
||||||
std::lock_guard<std::mutex> lock(globalMutex());
|
std::lock_guard<std::mutex> lock(globalMutex());
|
||||||
@ -76,10 +77,12 @@ public:
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (cached) {
|
if (cached) {
|
||||||
cb(inet);
|
cb(inet);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (loop_->isInLoopThread()) {
|
if (loop_->isInLoopThread()) {
|
||||||
resolveInLoop(hostname, cb);
|
resolveInLoop(hostname, cb);
|
||||||
} else {
|
} else {
|
||||||
@ -94,14 +97,15 @@ private:
|
|||||||
AresResolver *owner_;
|
AresResolver *owner_;
|
||||||
Callback callback_;
|
Callback callback_;
|
||||||
std::string hostname_;
|
std::string hostname_;
|
||||||
QueryData(AresResolver *o,
|
|
||||||
const Callback &cb,
|
QueryData(AresResolver *o, const Callback &cb, const std::string &hostname) :
|
||||||
const std::string &hostname) :
|
|
||||||
owner_(o), callback_(cb), hostname_(hostname) {
|
owner_(o), callback_(cb), hostname_(hostname) {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
void resolveInLoop(const std::string &hostname, const Callback &cb);
|
void resolveInLoop(const std::string &hostname, const Callback &cb);
|
||||||
void init();
|
void init();
|
||||||
|
|
||||||
EventLoop *loop_;
|
EventLoop *loop_;
|
||||||
ares_channel ctx_{ nullptr };
|
ares_channel ctx_{ nullptr };
|
||||||
bool timerActive_{ false };
|
bool timerActive_{ false };
|
||||||
@ -109,36 +113,33 @@ private:
|
|||||||
ChannelList channels_;
|
ChannelList channels_;
|
||||||
static std::unordered_map<std::string,
|
static std::unordered_map<std::string,
|
||||||
std::pair<struct in_addr, Date> > &
|
std::pair<struct in_addr, Date> > &
|
||||||
|
|
||||||
globalCache() {
|
globalCache() {
|
||||||
static std::unordered_map<std::string,
|
static std::unordered_map<std::string, std::pair<struct in_addr, Date> > dnsCache;
|
||||||
std::pair<struct in_addr, Date> >
|
|
||||||
dnsCache;
|
|
||||||
return dnsCache;
|
return dnsCache;
|
||||||
}
|
}
|
||||||
|
|
||||||
static std::mutex &globalMutex() {
|
static std::mutex &globalMutex() {
|
||||||
static std::mutex mutex_;
|
static std::mutex mutex_;
|
||||||
return mutex_;
|
return mutex_;
|
||||||
}
|
}
|
||||||
|
|
||||||
static EventLoop *getLoop() {
|
static EventLoop *getLoop() {
|
||||||
static EventLoopThread loopThread;
|
static EventLoopThread loopThread;
|
||||||
loopThread.run();
|
loopThread.run();
|
||||||
return loopThread.getLoop();
|
return loopThread.getLoop();
|
||||||
}
|
}
|
||||||
|
|
||||||
const size_t timeout_{ 60 };
|
const size_t timeout_{ 60 };
|
||||||
|
|
||||||
void onRead(int sockfd);
|
void onRead(int sockfd);
|
||||||
void onTimer();
|
void onTimer();
|
||||||
void onQueryResult(int status,
|
void onQueryResult(int status, struct hostent *result, const std::string &hostname, const Callback &callback);
|
||||||
struct hostent *result,
|
|
||||||
const std::string &hostname,
|
|
||||||
const Callback &callback);
|
|
||||||
void onSockCreate(int sockfd, int type);
|
void onSockCreate(int sockfd, int type);
|
||||||
void onSockStateChange(int sockfd, bool read, bool write);
|
void onSockStateChange(int sockfd, bool read, bool write);
|
||||||
|
|
||||||
static void ares_hostcallback_(void *data,
|
static void ares_hostcallback_(void *data, int status, int timeouts, struct hostent *hostent);
|
||||||
int status,
|
|
||||||
int timeouts,
|
|
||||||
struct hostent *hostent);
|
|
||||||
#ifdef _WIN32
|
#ifdef _WIN32
|
||||||
static int ares_sock_createcallback_(SOCKET sockfd, int type, void *data);
|
static int ares_sock_createcallback_(SOCKET sockfd, int type, void *data);
|
||||||
#else
|
#else
|
||||||
@ -152,9 +153,11 @@ private:
|
|||||||
#endif
|
#endif
|
||||||
int read,
|
int read,
|
||||||
int write);
|
int write);
|
||||||
|
|
||||||
struct LibraryInitializer {
|
struct LibraryInitializer {
|
||||||
LibraryInitializer();
|
LibraryInitializer();
|
||||||
~LibraryInitializer();
|
~LibraryInitializer();
|
||||||
};
|
};
|
||||||
|
|
||||||
static LibraryInitializer libraryInitializer_;
|
static LibraryInitializer libraryInitializer_;
|
||||||
};
|
};
|
||||||
|
@ -44,9 +44,11 @@ std::shared_ptr<Resolver> Resolver::newResolver(EventLoop *,
|
|||||||
size_t timeout) {
|
size_t timeout) {
|
||||||
return std::make_shared<NormalResolver>(timeout);
|
return std::make_shared<NormalResolver>(timeout);
|
||||||
}
|
}
|
||||||
|
|
||||||
bool Resolver::isCAresUsed() {
|
bool Resolver::isCAresUsed() {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
void NormalResolver::resolve(const std::string &hostname,
|
void NormalResolver::resolve(const std::string &hostname,
|
||||||
const Callback &callback) {
|
const Callback &callback) {
|
||||||
{
|
{
|
||||||
|
@ -35,10 +35,12 @@
|
|||||||
#include <thread>
|
#include <thread>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
|
//Resolver will be a ref
|
||||||
|
|
||||||
constexpr size_t kResolveBufferLength{ 16 * 1024 };
|
constexpr size_t kResolveBufferLength{ 16 * 1024 };
|
||||||
|
|
||||||
class NormalResolver : public Resolver,
|
class NormalResolver : public Resolver, public std::enable_shared_from_this<NormalResolver> {
|
||||||
public std::enable_shared_from_this<NormalResolver> {
|
|
||||||
protected:
|
protected:
|
||||||
NormalResolver(const NormalResolver &) = delete;
|
NormalResolver(const NormalResolver &) = delete;
|
||||||
NormalResolver &operator=(const NormalResolver &) = delete;
|
NormalResolver &operator=(const NormalResolver &) = delete;
|
||||||
@ -56,25 +58,23 @@ public:
|
|||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
static std::unordered_map<std::string,
|
static std::unordered_map<std::string, std::pair<InetAddress, Date> > &globalCache() {
|
||||||
std::pair<InetAddress, Date> > &
|
static std::unordered_map<std::string, std::pair<InetAddress, Date> > dnsCache_;
|
||||||
globalCache() {
|
|
||||||
static std::unordered_map<
|
|
||||||
std::string,
|
|
||||||
std::pair<InetAddress, Date> >
|
|
||||||
dnsCache_;
|
|
||||||
return dnsCache_;
|
return dnsCache_;
|
||||||
}
|
}
|
||||||
|
|
||||||
static std::mutex &globalMutex() {
|
static std::mutex &globalMutex() {
|
||||||
static std::mutex mutex_;
|
static std::mutex mutex_;
|
||||||
return mutex_;
|
return mutex_;
|
||||||
}
|
}
|
||||||
|
|
||||||
static ConcurrentTaskQueue &concurrentTaskQueue() {
|
static ConcurrentTaskQueue &concurrentTaskQueue() {
|
||||||
static ConcurrentTaskQueue queue(
|
static ConcurrentTaskQueue queue(
|
||||||
std::thread::hardware_concurrency() < 8 ? 8 : std::thread::hardware_concurrency(),
|
std::thread::hardware_concurrency() < 8 ? 8 : std::thread::hardware_concurrency(),
|
||||||
"Dns Queue");
|
"Dns Queue");
|
||||||
return queue;
|
return queue;
|
||||||
}
|
}
|
||||||
|
|
||||||
const size_t timeout_;
|
const size_t timeout_;
|
||||||
std::vector<char> resolveBuffer_;
|
std::vector<char> resolveBuffer_;
|
||||||
};
|
};
|
||||||
|
@ -29,9 +29,9 @@
|
|||||||
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||||
|
|
||||||
#include "socket.h"
|
#include "socket.h"
|
||||||
|
#include "core/log/logger.h"
|
||||||
#include <assert.h>
|
#include <assert.h>
|
||||||
#include <sys/types.h>
|
#include <sys/types.h>
|
||||||
#include "core/log/logger.h"
|
|
||||||
#ifdef _WIN32
|
#ifdef _WIN32
|
||||||
#include <ws2tcpip.h>
|
#include <ws2tcpip.h>
|
||||||
#else
|
#else
|
||||||
@ -39,7 +39,6 @@
|
|||||||
#include <sys/socket.h>
|
#include <sys/socket.h>
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
|
||||||
bool Socket::isSelfConnect(int sockfd) {
|
bool Socket::isSelfConnect(int sockfd) {
|
||||||
struct sockaddr_in6 localaddr = getLocalAddr(sockfd);
|
struct sockaddr_in6 localaddr = getLocalAddr(sockfd);
|
||||||
struct sockaddr_in6 peeraddr = getPeerAddr(sockfd);
|
struct sockaddr_in6 peeraddr = getPeerAddr(sockfd);
|
||||||
@ -75,6 +74,7 @@ void Socket::bindAddress(const InetAddress &localaddr) {
|
|||||||
exit(1);
|
exit(1);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void Socket::listen() {
|
void Socket::listen() {
|
||||||
assert(sockFd_ > 0);
|
assert(sockFd_ > 0);
|
||||||
int ret = ::listen(sockFd_, SOMAXCONN);
|
int ret = ::listen(sockFd_, SOMAXCONN);
|
||||||
@ -83,6 +83,7 @@ void Socket::listen() {
|
|||||||
exit(1);
|
exit(1);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
int Socket::accept(InetAddress *peeraddr) {
|
int Socket::accept(InetAddress *peeraddr) {
|
||||||
struct sockaddr_in6 addr6;
|
struct sockaddr_in6 addr6;
|
||||||
memset(&addr6, 0, sizeof(addr6));
|
memset(&addr6, 0, sizeof(addr6));
|
||||||
@ -102,6 +103,7 @@ int Socket::accept(InetAddress *peeraddr) {
|
|||||||
}
|
}
|
||||||
return connfd;
|
return connfd;
|
||||||
}
|
}
|
||||||
|
|
||||||
void Socket::closeWrite() {
|
void Socket::closeWrite() {
|
||||||
#ifndef _WIN32
|
#ifndef _WIN32
|
||||||
if (::shutdown(sockFd_, SHUT_WR) < 0)
|
if (::shutdown(sockFd_, SHUT_WR) < 0)
|
||||||
@ -112,6 +114,7 @@ void Socket::closeWrite() {
|
|||||||
LOG_SYSERR << "sockets::shutdownWrite";
|
LOG_SYSERR << "sockets::shutdownWrite";
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
int Socket::read(char *buffer, uint64_t len) {
|
int Socket::read(char *buffer, uint64_t len) {
|
||||||
#ifndef _WIN32
|
#ifndef _WIN32
|
||||||
return ::read(sockFd_, buffer, len);
|
return ::read(sockFd_, buffer, len);
|
||||||
|
@ -32,8 +32,8 @@
|
|||||||
|
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include "core/net/inet_address.h"
|
|
||||||
#include "core/log/logger.h"
|
#include "core/log/logger.h"
|
||||||
|
#include "core/net/inet_address.h"
|
||||||
#include <string>
|
#include <string>
|
||||||
#ifndef _WIN32
|
#ifndef _WIN32
|
||||||
#include <unistd.h>
|
#include <unistd.h>
|
||||||
@ -41,7 +41,6 @@
|
|||||||
#include <fcntl.h>
|
#include <fcntl.h>
|
||||||
|
|
||||||
class Socket {
|
class Socket {
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
Socket(const Socket &) = delete;
|
Socket(const Socket &) = delete;
|
||||||
Socket &operator=(const Socket &) = delete;
|
Socket &operator=(const Socket &) = delete;
|
||||||
@ -63,6 +62,7 @@ public:
|
|||||||
LOG_SYSERR << "sockets::createNonblockingOrDie";
|
LOG_SYSERR << "sockets::createNonblockingOrDie";
|
||||||
exit(1);
|
exit(1);
|
||||||
}
|
}
|
||||||
|
|
||||||
LOG_TRACE << "sock=" << sock;
|
LOG_TRACE << "sock=" << sock;
|
||||||
return sock;
|
return sock;
|
||||||
}
|
}
|
||||||
@ -85,15 +85,9 @@ public:
|
|||||||
|
|
||||||
static int connect(int sockfd, const InetAddress &addr) {
|
static int connect(int sockfd, const InetAddress &addr) {
|
||||||
if (addr.isIpV6())
|
if (addr.isIpV6())
|
||||||
return ::connect(sockfd,
|
return ::connect(sockfd, addr.getSockAddr(), static_cast<socklen_t>(sizeof(struct sockaddr_in6)));
|
||||||
addr.getSockAddr(),
|
|
||||||
static_cast<socklen_t>(
|
|
||||||
sizeof(struct sockaddr_in6)));
|
|
||||||
else
|
else
|
||||||
return ::connect(sockfd,
|
return ::connect(sockfd, addr.getSockAddr(), static_cast<socklen_t>(sizeof(struct sockaddr_in)));
|
||||||
addr.getSockAddr(),
|
|
||||||
static_cast<socklen_t>(
|
|
||||||
sizeof(struct sockaddr_in)));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
static bool isSelfConnect(int sockfd);
|
static bool isSelfConnect(int sockfd);
|
||||||
@ -101,7 +95,9 @@ public:
|
|||||||
explicit Socket(int sockfd) :
|
explicit Socket(int sockfd) :
|
||||||
sockFd_(sockfd) {
|
sockFd_(sockfd) {
|
||||||
}
|
}
|
||||||
|
|
||||||
~Socket();
|
~Socket();
|
||||||
|
|
||||||
/// abort if address in use
|
/// abort if address in use
|
||||||
void bindAddress(const InetAddress &localaddr);
|
void bindAddress(const InetAddress &localaddr);
|
||||||
/// abort if address in use
|
/// abort if address in use
|
||||||
@ -109,9 +105,11 @@ public:
|
|||||||
int accept(InetAddress *peeraddr);
|
int accept(InetAddress *peeraddr);
|
||||||
void closeWrite();
|
void closeWrite();
|
||||||
int read(char *buffer, uint64_t len);
|
int read(char *buffer, uint64_t len);
|
||||||
|
|
||||||
int fd() {
|
int fd() {
|
||||||
return sockFd_;
|
return sockFd_;
|
||||||
}
|
}
|
||||||
|
|
||||||
static struct sockaddr_in6 getLocalAddr(int sockfd);
|
static struct sockaddr_in6 getLocalAddr(int sockfd);
|
||||||
static struct sockaddr_in6 getPeerAddr(int sockfd);
|
static struct sockaddr_in6 getPeerAddr(int sockfd);
|
||||||
|
|
||||||
|
@ -54,9 +54,7 @@ TcpClient::IgnoreSigPipe TcpClient::initObj;
|
|||||||
#endif
|
#endif
|
||||||
|
|
||||||
static void defaultConnectionCallback(const TcpConnectionPtr &conn) {
|
static void defaultConnectionCallback(const TcpConnectionPtr &conn) {
|
||||||
LOG_TRACE << conn->localAddr().toIpPort() << " -> "
|
LOG_TRACE << conn->localAddr().toIpPort() << " -> " << conn->peerAddr().toIpPort() << " is " << (conn->connected() ? "UP" : "DOWN");
|
||||||
<< conn->peerAddr().toIpPort() << " is "
|
|
||||||
<< (conn->connected() ? "UP" : "DOWN");
|
|
||||||
// do not call conn->forceClose(), because some users want to register
|
// do not call conn->forceClose(), because some users want to register
|
||||||
// message callback only.
|
// message callback only.
|
||||||
}
|
}
|
||||||
@ -65,9 +63,7 @@ static void defaultMessageCallback(const TcpConnectionPtr &, MsgBuffer *buf) {
|
|||||||
buf->retrieveAll();
|
buf->retrieveAll();
|
||||||
}
|
}
|
||||||
|
|
||||||
TcpClient::TcpClient(EventLoop *loop,
|
TcpClient::TcpClient(EventLoop *loop, const InetAddress &serverAddr, const std::string &nameArg) :
|
||||||
const InetAddress &serverAddr,
|
|
||||||
const std::string &nameArg) :
|
|
||||||
loop_(loop),
|
loop_(loop),
|
||||||
connector_(new Connector(loop, serverAddr, false)),
|
connector_(new Connector(loop, serverAddr, false)),
|
||||||
name_(nameArg),
|
name_(nameArg),
|
||||||
@ -75,23 +71,27 @@ TcpClient::TcpClient(EventLoop *loop,
|
|||||||
messageCallback_(defaultMessageCallback),
|
messageCallback_(defaultMessageCallback),
|
||||||
retry_(false),
|
retry_(false),
|
||||||
connect_(true) {
|
connect_(true) {
|
||||||
connector_->setNewConnectionCallback(
|
|
||||||
std::bind(&TcpClient::newConnection, this, _1));
|
connector_->setNewConnectionCallback(std::bind(&TcpClient::newConnection, this, _1));
|
||||||
|
|
||||||
connector_->setErrorCallback([this]() {
|
connector_->setErrorCallback([this]() {
|
||||||
if (connectionErrorCallback_) {
|
if (connectionErrorCallback_) {
|
||||||
connectionErrorCallback_();
|
connectionErrorCallback_();
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
LOG_TRACE << "TcpClient::TcpClient[" << name_ << "] - connector ";
|
LOG_TRACE << "TcpClient::TcpClient[" << name_ << "] - connector ";
|
||||||
}
|
}
|
||||||
|
|
||||||
TcpClient::~TcpClient() {
|
TcpClient::~TcpClient() {
|
||||||
LOG_TRACE << "TcpClient::~TcpClient[" << name_ << "] - connector ";
|
LOG_TRACE << "TcpClient::~TcpClient[" << name_ << "] - connector ";
|
||||||
|
|
||||||
TcpConnectionImplPtr conn;
|
TcpConnectionImplPtr conn;
|
||||||
{
|
{
|
||||||
std::lock_guard<std::mutex> lock(mutex_);
|
std::lock_guard<std::mutex> lock(mutex_);
|
||||||
conn = std::dynamic_pointer_cast<TcpConnectionImpl>(connection_);
|
conn = std::dynamic_pointer_cast<TcpConnectionImpl>(connection_);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (conn) {
|
if (conn) {
|
||||||
assert(loop_ == conn->getLoop());
|
assert(loop_ == conn->getLoop());
|
||||||
// TODO: not 100% safe, if we are in different thread
|
// TODO: not 100% safe, if we are in different thread
|
||||||
@ -104,6 +104,7 @@ TcpClient::~TcpClient() {
|
|||||||
});
|
});
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
conn->forceClose();
|
conn->forceClose();
|
||||||
} else {
|
} else {
|
||||||
/// TODO need test in this condition
|
/// TODO need test in this condition
|
||||||
@ -142,39 +143,34 @@ void TcpClient::newConnection(int sockfd) {
|
|||||||
// TODO poll with zero timeout to double confirm the new connection
|
// TODO poll with zero timeout to double confirm the new connection
|
||||||
// TODO use make_shared if necessary
|
// TODO use make_shared if necessary
|
||||||
std::shared_ptr<TcpConnectionImpl> conn;
|
std::shared_ptr<TcpConnectionImpl> conn;
|
||||||
|
|
||||||
if (sslCtxPtr_) {
|
if (sslCtxPtr_) {
|
||||||
#ifdef USE_OPENSSL
|
#ifdef USE_OPENSSL
|
||||||
conn = std::make_shared<TcpConnectionImpl>(loop_,
|
conn = std::make_shared<TcpConnectionImpl>(loop_, sockfd, localAddr, peerAddr, sslCtxPtr_, false, validateCert_, SSLHostName_);
|
||||||
sockfd,
|
|
||||||
localAddr,
|
|
||||||
peerAddr,
|
|
||||||
sslCtxPtr_,
|
|
||||||
false,
|
|
||||||
validateCert_,
|
|
||||||
SSLHostName_);
|
|
||||||
#else
|
#else
|
||||||
LOG_FATAL << "OpenSSL is not found in your system!";
|
LOG_FATAL << "OpenSSL is not found in your system!";
|
||||||
abort();
|
abort();
|
||||||
#endif
|
#endif
|
||||||
} else {
|
} else {
|
||||||
conn = std::make_shared<TcpConnectionImpl>(loop_,
|
conn = std::make_shared<TcpConnectionImpl>(loop_, sockfd, localAddr, peerAddr);
|
||||||
sockfd,
|
|
||||||
localAddr,
|
|
||||||
peerAddr);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
conn->setConnectionCallback(connectionCallback_);
|
conn->setConnectionCallback(connectionCallback_);
|
||||||
conn->setRecvMsgCallback(messageCallback_);
|
conn->setRecvMsgCallback(messageCallback_);
|
||||||
conn->setWriteCompleteCallback(writeCompleteCallback_);
|
conn->setWriteCompleteCallback(writeCompleteCallback_);
|
||||||
conn->setCloseCallback(std::bind(&TcpClient::removeConnection, this, _1));
|
conn->setCloseCallback(std::bind(&TcpClient::removeConnection, this, _1));
|
||||||
|
|
||||||
{
|
{
|
||||||
std::lock_guard<std::mutex> lock(mutex_);
|
std::lock_guard<std::mutex> lock(mutex_);
|
||||||
connection_ = conn;
|
connection_ = conn;
|
||||||
}
|
}
|
||||||
|
|
||||||
conn->setSSLErrorCallback([this](SSLError err) {
|
conn->setSSLErrorCallback([this](SSLError err) {
|
||||||
if (sslErrorCallback_) {
|
if (sslErrorCallback_) {
|
||||||
sslErrorCallback_(err);
|
sslErrorCallback_(err);
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
conn->connectEstablished();
|
conn->connectEstablished();
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -188,30 +184,23 @@ void TcpClient::removeConnection(const TcpConnectionPtr &conn) {
|
|||||||
connection_.reset();
|
connection_.reset();
|
||||||
}
|
}
|
||||||
|
|
||||||
loop_->queueInLoop(
|
loop_->queueInLoop(std::bind(&TcpConnectionImpl::connectDestroyed, std::dynamic_pointer_cast<TcpConnectionImpl>(conn)));
|
||||||
std::bind(&TcpConnectionImpl::connectDestroyed,
|
|
||||||
std::dynamic_pointer_cast<TcpConnectionImpl>(conn)));
|
|
||||||
if (retry_ && connect_) {
|
if (retry_ && connect_) {
|
||||||
LOG_TRACE << "TcpClient::connect[" << name_ << "] - Reconnecting to "
|
LOG_TRACE << "TcpClient::connect[" << name_ << "] - Reconnecting to " << connector_->serverAddress().toIpPort();
|
||||||
<< connector_->serverAddress().toIpPort();
|
|
||||||
connector_->restart();
|
connector_->restart();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void TcpClient::enableSSL(
|
void TcpClient::enableSSL(bool useOldTLS, bool validateCert, std::string hostname, const std::vector<std::pair<std::string, std::string> > &sslConfCmds) {
|
||||||
bool useOldTLS,
|
|
||||||
bool validateCert,
|
|
||||||
std::string hostname,
|
|
||||||
const std::vector<std::pair<std::string, std::string> > &sslConfCmds) {
|
|
||||||
#ifdef USE_OPENSSL
|
#ifdef USE_OPENSSL
|
||||||
/* Create a new OpenSSL context */
|
/* Create a new OpenSSL context */
|
||||||
sslCtxPtr_ = newSSLContext(useOldTLS, validateCert, sslConfCmds);
|
sslCtxPtr_ = newSSLContext(useOldTLS, validateCert, sslConfCmds);
|
||||||
|
|
||||||
validateCert_ = validateCert;
|
validateCert_ = validateCert;
|
||||||
|
|
||||||
if (!hostname.empty()) {
|
if (!hostname.empty()) {
|
||||||
std::transform(hostname.begin(),
|
std::transform(hostname.begin(), hostname.end(), hostname.begin(), tolower);
|
||||||
hostname.end(),
|
|
||||||
hostname.begin(),
|
|
||||||
tolower);
|
|
||||||
SSLHostName_ = std::move(hostname);
|
SSLHostName_ = std::move(hostname);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -42,10 +42,7 @@
|
|||||||
class Connector;
|
class Connector;
|
||||||
using ConnectorPtr = std::shared_ptr<Connector>;
|
using ConnectorPtr = std::shared_ptr<Connector>;
|
||||||
class SSLContext;
|
class SSLContext;
|
||||||
/**
|
|
||||||
* @brief This class represents a TCP client.
|
|
||||||
*
|
|
||||||
*/
|
|
||||||
class TcpClient {
|
class TcpClient {
|
||||||
protected:
|
protected:
|
||||||
TcpClient(const TcpClient &) = delete;
|
TcpClient(const TcpClient &) = delete;
|
||||||
@ -55,88 +52,35 @@ protected:
|
|||||||
TcpClient &operator=(TcpClient &&) noexcept(true) = default;
|
TcpClient &operator=(TcpClient &&) noexcept(true) = default;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
/**
|
TcpClient(EventLoop *loop, const InetAddress &serverAddr, const std::string &nameArg);
|
||||||
* @brief Construct a new TCP client instance.
|
|
||||||
*
|
|
||||||
* @param loop The event loop in which the client runs.
|
|
||||||
* @param serverAddr The address of the server.
|
|
||||||
* @param nameArg The name of the client.
|
|
||||||
*/
|
|
||||||
TcpClient(EventLoop *loop,
|
|
||||||
const InetAddress &serverAddr,
|
|
||||||
const std::string &nameArg);
|
|
||||||
~TcpClient();
|
~TcpClient();
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief Connect to the server.
|
|
||||||
*
|
|
||||||
*/
|
|
||||||
void connect();
|
void connect();
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief Disconnect from the server.
|
|
||||||
*
|
|
||||||
*/
|
|
||||||
void disconnect();
|
void disconnect();
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief Stop connecting to the server.
|
|
||||||
*
|
|
||||||
*/
|
|
||||||
void stop();
|
void stop();
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief Get the TCP connection to the server.
|
|
||||||
*
|
|
||||||
* @return TcpConnectionPtr
|
|
||||||
*/
|
|
||||||
TcpConnectionPtr connection() const {
|
TcpConnectionPtr connection() const {
|
||||||
std::lock_guard<std::mutex> lock(mutex_);
|
std::lock_guard<std::mutex> lock(mutex_);
|
||||||
return connection_;
|
return connection_;
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief Get the event loop.
|
|
||||||
*
|
|
||||||
* @return EventLoop*
|
|
||||||
*/
|
|
||||||
EventLoop *getLoop() const {
|
EventLoop *getLoop() const {
|
||||||
return loop_;
|
return loop_;
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief Check whether the client re-connect to the server.
|
|
||||||
*
|
|
||||||
* @return true
|
|
||||||
* @return false
|
|
||||||
*/
|
|
||||||
bool retry() const {
|
bool retry() const {
|
||||||
return retry_;
|
return retry_;
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief Enable retrying.
|
|
||||||
*
|
|
||||||
*/
|
|
||||||
void enableRetry() {
|
void enableRetry() {
|
||||||
retry_ = true;
|
retry_ = true;
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief Get the name of the client.
|
|
||||||
*
|
|
||||||
* @return const std::string&
|
|
||||||
*/
|
|
||||||
const std::string &name() const {
|
const std::string &name() const {
|
||||||
return name_;
|
return name_;
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief Set the connection callback.
|
|
||||||
*
|
|
||||||
* @param cb The callback is called when the connection to the server is
|
|
||||||
* established or closed.
|
|
||||||
*/
|
|
||||||
void setConnectionCallback(const ConnectionCallback &cb) {
|
void setConnectionCallback(const ConnectionCallback &cb) {
|
||||||
connectionCallback_ = cb;
|
connectionCallback_ = cb;
|
||||||
}
|
}
|
||||||
@ -144,37 +88,18 @@ public:
|
|||||||
connectionCallback_ = std::move(cb);
|
connectionCallback_ = std::move(cb);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief Set the connection error callback.
|
|
||||||
*
|
|
||||||
* @param cb The callback is called when an error occurs during connecting
|
|
||||||
* to the server.
|
|
||||||
*/
|
|
||||||
void setConnectionErrorCallback(const ConnectionErrorCallback &cb) {
|
void setConnectionErrorCallback(const ConnectionErrorCallback &cb) {
|
||||||
connectionErrorCallback_ = cb;
|
connectionErrorCallback_ = cb;
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief Set the message callback.
|
|
||||||
*
|
|
||||||
* @param cb The callback is called when some data is received from the
|
|
||||||
* server.
|
|
||||||
*/
|
|
||||||
void setMessageCallback(const RecvMessageCallback &cb) {
|
void setMessageCallback(const RecvMessageCallback &cb) {
|
||||||
messageCallback_ = cb;
|
messageCallback_ = cb;
|
||||||
}
|
}
|
||||||
void setMessageCallback(RecvMessageCallback &&cb) {
|
void setMessageCallback(RecvMessageCallback &&cb) {
|
||||||
messageCallback_ = std::move(cb);
|
messageCallback_ = std::move(cb);
|
||||||
}
|
}
|
||||||
/// Set write complete callback.
|
|
||||||
/// Not thread safe.
|
|
||||||
|
|
||||||
/**
|
/// Not thread safe.
|
||||||
* @brief Set the write complete callback.
|
|
||||||
*
|
|
||||||
* @param cb The callback is called when data to send is written to the
|
|
||||||
* socket.
|
|
||||||
*/
|
|
||||||
void setWriteCompleteCallback(const WriteCompleteCallback &cb) {
|
void setWriteCompleteCallback(const WriteCompleteCallback &cb) {
|
||||||
writeCompleteCallback_ = cb;
|
writeCompleteCallback_ = cb;
|
||||||
}
|
}
|
||||||
@ -182,10 +107,6 @@ public:
|
|||||||
writeCompleteCallback_ = std::move(cb);
|
writeCompleteCallback_ = std::move(cb);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief Set the callback for errors of SSL
|
|
||||||
* @param cb The callback is called when an SSL error occurs.
|
|
||||||
*/
|
|
||||||
void setSSLErrorCallback(const SSLErrorCallback &cb) {
|
void setSSLErrorCallback(const SSLErrorCallback &cb) {
|
||||||
sslErrorCallback_ = cb;
|
sslErrorCallback_ = cb;
|
||||||
}
|
}
|
||||||
@ -193,24 +114,7 @@ public:
|
|||||||
sslErrorCallback_ = std::move(cb);
|
sslErrorCallback_ = std::move(cb);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
void enableSSL(bool useOldTLS = false, bool validateCert = true, std::string hostname = "", const std::vector<std::pair<std::string, std::string> > &sslConfCmds = {});
|
||||||
* @brief Enable SSL encryption.
|
|
||||||
* @param useOldTLS If true, the TLS 1.0 and 1.1 are supported by the
|
|
||||||
* client.
|
|
||||||
* @param validateCert If true, we try to validate if the peer's SSL cert
|
|
||||||
* is valid.
|
|
||||||
* @param hostname The server hostname for SNI. If it is empty, the SNI is
|
|
||||||
* not used.
|
|
||||||
* @param sslConfCmds The commands used to call the SSL_CONF_cmd function in
|
|
||||||
* OpenSSL.
|
|
||||||
* @note It's well known that TLS 1.0 and 1.1 are not considered secure in
|
|
||||||
* 2020. And it's a good practice to only use TLS 1.2 and above.
|
|
||||||
*/
|
|
||||||
void enableSSL(bool useOldTLS = false,
|
|
||||||
bool validateCert = true,
|
|
||||||
std::string hostname = "",
|
|
||||||
const std::vector<std::pair<std::string, std::string> >
|
|
||||||
&sslConfCmds = {});
|
|
||||||
|
|
||||||
private:
|
private:
|
||||||
/// Not thread safe, but in loop
|
/// Not thread safe, but in loop
|
||||||
@ -234,6 +138,7 @@ private:
|
|||||||
std::shared_ptr<SSLContext> sslCtxPtr_;
|
std::shared_ptr<SSLContext> sslCtxPtr_;
|
||||||
bool validateCert_{ false };
|
bool validateCert_{ false };
|
||||||
std::string SSLHostName_;
|
std::string SSLHostName_;
|
||||||
|
|
||||||
#ifndef _WIN32
|
#ifndef _WIN32
|
||||||
class IgnoreSigPipe {
|
class IgnoreSigPipe {
|
||||||
public:
|
public:
|
||||||
|
@ -39,26 +39,13 @@
|
|||||||
#include <string>
|
#include <string>
|
||||||
|
|
||||||
class SSLContext;
|
class SSLContext;
|
||||||
std::shared_ptr<SSLContext> newSSLServerContext(
|
std::shared_ptr<SSLContext> newSSLServerContext(const std::string &certPath, const std::string &keyPath, bool useOldTLS = false, const std::vector<std::pair<std::string, std::string> > &sslConfCmds = {});
|
||||||
const std::string &certPath,
|
|
||||||
const std::string &keyPath,
|
|
||||||
bool useOldTLS = false,
|
|
||||||
const std::vector<std::pair<std::string, std::string> > &sslConfCmds = {});
|
|
||||||
/**
|
|
||||||
* @brief This class represents a TCP connection.
|
|
||||||
*
|
|
||||||
*/
|
|
||||||
class TcpConnection {
|
class TcpConnection {
|
||||||
public:
|
public:
|
||||||
TcpConnection() = default;
|
TcpConnection() = default;
|
||||||
virtual ~TcpConnection(){};
|
virtual ~TcpConnection(){};
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief Send some data to the peer.
|
|
||||||
*
|
|
||||||
* @param msg
|
|
||||||
* @param len
|
|
||||||
*/
|
|
||||||
virtual void send(const char *msg, size_t len) = 0;
|
virtual void send(const char *msg, size_t len) = 0;
|
||||||
virtual void send(const void *msg, size_t len) = 0;
|
virtual void send(const void *msg, size_t len) = 0;
|
||||||
virtual void send(const std::string &msg) = 0;
|
virtual void send(const std::string &msg) = 0;
|
||||||
@ -68,95 +55,27 @@ public:
|
|||||||
virtual void send(const std::shared_ptr<std::string> &msgPtr) = 0;
|
virtual void send(const std::shared_ptr<std::string> &msgPtr) = 0;
|
||||||
virtual void send(const std::shared_ptr<MsgBuffer> &msgPtr) = 0;
|
virtual void send(const std::shared_ptr<MsgBuffer> &msgPtr) = 0;
|
||||||
|
|
||||||
/**
|
virtual void sendFile(const char *fileName, size_t offset = 0, size_t length = 0) = 0;
|
||||||
* @brief Send a file to the peer.
|
|
||||||
*
|
|
||||||
* @param fileName
|
|
||||||
* @param offset
|
|
||||||
* @param length
|
|
||||||
*/
|
|
||||||
virtual void sendFile(const char *fileName,
|
|
||||||
size_t offset = 0,
|
|
||||||
size_t length = 0) = 0;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief Get the local address of the connection.
|
|
||||||
*
|
|
||||||
* @return const InetAddress&
|
|
||||||
*/
|
|
||||||
virtual const InetAddress &localAddr() const = 0;
|
virtual const InetAddress &localAddr() const = 0;
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief Get the remote address of the connection.
|
|
||||||
*
|
|
||||||
* @return const InetAddress&
|
|
||||||
*/
|
|
||||||
virtual const InetAddress &peerAddr() const = 0;
|
virtual const InetAddress &peerAddr() const = 0;
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief Return true if the connection is established.
|
|
||||||
*
|
|
||||||
* @return true
|
|
||||||
* @return false
|
|
||||||
*/
|
|
||||||
virtual bool connected() const = 0;
|
virtual bool connected() const = 0;
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief Return false if the connection is established.
|
|
||||||
*
|
|
||||||
* @return true
|
|
||||||
* @return false
|
|
||||||
*/
|
|
||||||
virtual bool disconnected() const = 0;
|
virtual bool disconnected() const = 0;
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief Get the buffer in which the received data stored.
|
|
||||||
*
|
|
||||||
* @return MsgBuffer*
|
|
||||||
*/
|
|
||||||
virtual MsgBuffer *getRecvBuffer() = 0;
|
virtual MsgBuffer *getRecvBuffer() = 0;
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief Set the high water mark callback
|
|
||||||
*
|
|
||||||
* @param cb The callback is called when the data in sending buffer is
|
|
||||||
* larger than the water mark.
|
|
||||||
* @param markLen The water mark in bytes.
|
|
||||||
*/
|
|
||||||
virtual void setHighWaterMarkCallback(const HighWaterMarkCallback &cb,
|
|
||||||
size_t markLen) = 0;
|
|
||||||
|
|
||||||
/**
|
virtual void setHighWaterMarkCallback(const HighWaterMarkCallback &cb, size_t markLen) = 0;
|
||||||
* @brief Set the TCP_NODELAY option to the socket.
|
|
||||||
*
|
|
||||||
* @param on
|
|
||||||
*/
|
|
||||||
virtual void setTcpNoDelay(bool on) = 0;
|
virtual void setTcpNoDelay(bool on) = 0;
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief Shutdown the connection.
|
|
||||||
* @note This method only closes the writing direction.
|
|
||||||
*/
|
|
||||||
virtual void shutdown() = 0;
|
virtual void shutdown() = 0;
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief Close the connection forcefully.
|
|
||||||
*
|
|
||||||
*/
|
|
||||||
virtual void forceClose() = 0;
|
virtual void forceClose() = 0;
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief Get the event loop in which the connection I/O is handled.
|
|
||||||
*
|
|
||||||
* @return EventLoop*
|
|
||||||
*/
|
|
||||||
virtual EventLoop *getLoop() = 0;
|
virtual EventLoop *getLoop() = 0;
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief Set the custom data on the connection.
|
|
||||||
*
|
|
||||||
* @param context
|
|
||||||
*/
|
|
||||||
void setContext(const std::shared_ptr<void> &context) {
|
void setContext(const std::shared_ptr<void> &context) {
|
||||||
contextPtr_ = context;
|
contextPtr_ = context;
|
||||||
}
|
}
|
||||||
@ -164,98 +83,30 @@ public:
|
|||||||
contextPtr_ = std::move(context);
|
contextPtr_ = std::move(context);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief Get the custom data from the connection.
|
|
||||||
*
|
|
||||||
* @tparam T
|
|
||||||
* @return std::shared_ptr<T>
|
|
||||||
*/
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
std::shared_ptr<T> getContext() const {
|
std::shared_ptr<T> getContext() const {
|
||||||
return std::static_pointer_cast<T>(contextPtr_);
|
return std::static_pointer_cast<T>(contextPtr_);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief Return true if the custom data is set by user.
|
|
||||||
*
|
|
||||||
* @return true
|
|
||||||
* @return false
|
|
||||||
*/
|
|
||||||
bool hasContext() const {
|
bool hasContext() const {
|
||||||
return (bool)contextPtr_;
|
return (bool)contextPtr_;
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief Clear the custom data.
|
|
||||||
*
|
|
||||||
*/
|
|
||||||
void clearContext() {
|
void clearContext() {
|
||||||
contextPtr_.reset();
|
contextPtr_.reset();
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief Call this method to avoid being kicked off by TcpServer, refer to
|
|
||||||
* the kickoffIdleConnections method in the TcpServer class.
|
|
||||||
*
|
|
||||||
*/
|
|
||||||
virtual void keepAlive() = 0;
|
virtual void keepAlive() = 0;
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief Return true if the keepAlive() method is called.
|
|
||||||
*
|
|
||||||
* @return true
|
|
||||||
* @return false
|
|
||||||
*/
|
|
||||||
virtual bool isKeepAlive() = 0;
|
virtual bool isKeepAlive() = 0;
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief Return the number of bytes sent
|
|
||||||
*
|
|
||||||
* @return size_t
|
|
||||||
*/
|
|
||||||
virtual size_t bytesSent() const = 0;
|
virtual size_t bytesSent() const = 0;
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief Return the number of bytes received.
|
|
||||||
*
|
|
||||||
* @return size_t
|
|
||||||
*/
|
|
||||||
virtual size_t bytesReceived() const = 0;
|
virtual size_t bytesReceived() const = 0;
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief Check whether the connection is SSL encrypted.
|
|
||||||
*
|
|
||||||
* @return true
|
|
||||||
* @return false
|
|
||||||
*/
|
|
||||||
virtual bool isSSLConnection() const = 0;
|
virtual bool isSSLConnection() const = 0;
|
||||||
|
|
||||||
/**
|
virtual void startClientEncryption(std::function<void()> callback, bool useOldTLS = false, bool validateCert = true, std::string hostname = "", const std::vector<std::pair<std::string, std::string> > &sslConfCmds = {}) = 0;
|
||||||
* @brief Start the SSL encryption on the connection (as a client).
|
|
||||||
*
|
|
||||||
* @param callback The callback is called when the SSL connection is
|
|
||||||
* established.
|
|
||||||
* @param hostname The server hostname for SNI. If it is empty, the SNI is
|
|
||||||
* not used.
|
|
||||||
* @param sslConfCmds The commands used to call the SSL_CONF_cmd function in
|
|
||||||
* OpenSSL.
|
|
||||||
*/
|
|
||||||
virtual void startClientEncryption(
|
|
||||||
std::function<void()> callback,
|
|
||||||
bool useOldTLS = false,
|
|
||||||
bool validateCert = true,
|
|
||||||
std::string hostname = "",
|
|
||||||
const std::vector<std::pair<std::string, std::string> > &sslConfCmds = {}) = 0;
|
|
||||||
|
|
||||||
/**
|
virtual void startServerEncryption(const std::shared_ptr<SSLContext> &ctx, std::function<void()> callback) = 0;
|
||||||
* @brief Start the SSL encryption on the connection (as a server).
|
|
||||||
*
|
|
||||||
* @param ctx The SSL context.
|
|
||||||
* @param callback The callback is called when the SSL connection is
|
|
||||||
* established.
|
|
||||||
*/
|
|
||||||
virtual void startServerEncryption(const std::shared_ptr<SSLContext> &ctx,
|
|
||||||
std::function<void()> callback) = 0;
|
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
bool validateCert_ = false;
|
bool validateCert_ = false;
|
||||||
|
@ -28,20 +28,16 @@
|
|||||||
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||||
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||||
|
|
||||||
#include "core/loops/acceptor.h"
|
|
||||||
#include "core/net/connections/tcp_connection_impl.h"
|
|
||||||
#include "core/net/tcp_server.h"
|
#include "core/net/tcp_server.h"
|
||||||
#include "core/log/logger.h"
|
#include "core/log/logger.h"
|
||||||
|
#include "core/loops/acceptor.h"
|
||||||
|
#include "core/net/connections/tcp_connection_impl.h"
|
||||||
#include <functional>
|
#include <functional>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
using namespace std::placeholders;
|
using namespace std::placeholders;
|
||||||
|
|
||||||
TcpServer::TcpServer(EventLoop *loop,
|
TcpServer::TcpServer(EventLoop *loop, const InetAddress &address, const std::string &name, bool reUseAddr, bool reUsePort) :
|
||||||
const InetAddress &address,
|
|
||||||
const std::string &name,
|
|
||||||
bool reUseAddr,
|
|
||||||
bool reUsePort) :
|
|
||||||
loop_(loop),
|
loop_(loop),
|
||||||
acceptorPtr_(new Acceptor(loop, address, reUseAddr, reUsePort)),
|
acceptorPtr_(new Acceptor(loop, address, reUseAddr, reUsePort)),
|
||||||
serverName_(name),
|
serverName_(name),
|
||||||
@ -50,8 +46,8 @@ TcpServer::TcpServer(EventLoop *loop,
|
|||||||
<< " bytes]";
|
<< " bytes]";
|
||||||
buffer->retrieveAll();
|
buffer->retrieveAll();
|
||||||
}) {
|
}) {
|
||||||
acceptorPtr_->setNewConnectionCallback(
|
|
||||||
std::bind(&TcpServer::newConnection, this, _1, _2));
|
acceptorPtr_->setNewConnectionCallback(std::bind(&TcpServer::newConnection, this, _1, _2));
|
||||||
}
|
}
|
||||||
|
|
||||||
TcpServer::~TcpServer() {
|
TcpServer::~TcpServer() {
|
||||||
@ -69,14 +65,20 @@ void TcpServer::newConnection(int sockfd, const InetAddress &peer) {
|
|||||||
// LOG_TRACE<<"vector size:"<<str.size();
|
// LOG_TRACE<<"vector size:"<<str.size();
|
||||||
// size_t n=write(sockfd,&str[0],str.size());
|
// size_t n=write(sockfd,&str[0],str.size());
|
||||||
// LOG_TRACE<<"write "<<n<<" bytes";
|
// LOG_TRACE<<"write "<<n<<" bytes";
|
||||||
|
|
||||||
loop_->assertInLoopThread();
|
loop_->assertInLoopThread();
|
||||||
|
|
||||||
EventLoop *ioLoop = NULL;
|
EventLoop *ioLoop = NULL;
|
||||||
if (loopPoolPtr_ && loopPoolPtr_->size() > 0) {
|
if (loopPoolPtr_ && loopPoolPtr_->size() > 0) {
|
||||||
ioLoop = loopPoolPtr_->getNextLoop();
|
ioLoop = loopPoolPtr_->getNextLoop();
|
||||||
}
|
}
|
||||||
if (ioLoop == NULL)
|
|
||||||
|
if (ioLoop == NULL) {
|
||||||
ioLoop = loop_;
|
ioLoop = loop_;
|
||||||
|
}
|
||||||
|
|
||||||
std::shared_ptr<TcpConnectionImpl> newPtr;
|
std::shared_ptr<TcpConnectionImpl> newPtr;
|
||||||
|
|
||||||
if (sslCtxPtr_) {
|
if (sslCtxPtr_) {
|
||||||
#ifdef USE_OPENSSL
|
#ifdef USE_OPENSSL
|
||||||
newPtr = std::make_shared<TcpConnectionImpl>(
|
newPtr = std::make_shared<TcpConnectionImpl>(
|
||||||
@ -90,14 +92,14 @@ void TcpServer::newConnection(int sockfd, const InetAddress &peer) {
|
|||||||
abort();
|
abort();
|
||||||
#endif
|
#endif
|
||||||
} else {
|
} else {
|
||||||
newPtr = std::make_shared<TcpConnectionImpl>(
|
newPtr = std::make_shared<TcpConnectionImpl>(ioLoop, sockfd, InetAddress(Socket::getLocalAddr(sockfd)), peer);
|
||||||
ioLoop, sockfd, InetAddress(Socket::getLocalAddr(sockfd)), peer);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if (idleTimeout_ > 0) {
|
if (idleTimeout_ > 0) {
|
||||||
assert(timingWheelMap_[ioLoop]);
|
assert(timingWheelMap_[ioLoop]);
|
||||||
newPtr->enableKickingOff(idleTimeout_, timingWheelMap_[ioLoop]);
|
newPtr->enableKickingOff(idleTimeout_, timingWheelMap_[ioLoop]);
|
||||||
}
|
}
|
||||||
|
|
||||||
newPtr->setRecvMsgCallback(recvMessageCallback_);
|
newPtr->setRecvMsgCallback(recvMessageCallback_);
|
||||||
|
|
||||||
newPtr->setConnectionCallback(
|
newPtr->setConnectionCallback(
|
||||||
@ -105,11 +107,13 @@ void TcpServer::newConnection(int sockfd, const InetAddress &peer) {
|
|||||||
if (connectionCallback_)
|
if (connectionCallback_)
|
||||||
connectionCallback_(connectionPtr);
|
connectionCallback_(connectionPtr);
|
||||||
});
|
});
|
||||||
|
|
||||||
newPtr->setWriteCompleteCallback(
|
newPtr->setWriteCompleteCallback(
|
||||||
[this](const TcpConnectionPtr &connectionPtr) {
|
[this](const TcpConnectionPtr &connectionPtr) {
|
||||||
if (writeCompleteCallback_)
|
if (writeCompleteCallback_)
|
||||||
writeCompleteCallback_(connectionPtr);
|
writeCompleteCallback_(connectionPtr);
|
||||||
});
|
});
|
||||||
|
|
||||||
newPtr->setCloseCallback(std::bind(&TcpServer::connectionClosed, this, _1));
|
newPtr->setCloseCallback(std::bind(&TcpServer::connectionClosed, this, _1));
|
||||||
connSet_.insert(newPtr);
|
connSet_.insert(newPtr);
|
||||||
newPtr->connectEstablished();
|
newPtr->connectEstablished();
|
||||||
@ -143,6 +147,7 @@ void TcpServer::start() {
|
|||||||
acceptorPtr_->listen();
|
acceptorPtr_->listen();
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
void TcpServer::stop() {
|
void TcpServer::stop() {
|
||||||
loop_->runInLoop([this]() { acceptorPtr_.reset(); });
|
loop_->runInLoop([this]() { acceptorPtr_.reset(); });
|
||||||
for (auto connection : connSet_) {
|
for (auto connection : connSet_) {
|
||||||
@ -159,6 +164,7 @@ void TcpServer::stop() {
|
|||||||
f.get();
|
f.get();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void TcpServer::connectionClosed(const TcpConnectionPtr &connectionPtr) {
|
void TcpServer::connectionClosed(const TcpConnectionPtr &connectionPtr) {
|
||||||
LOG_TRACE << "connectionClosed";
|
LOG_TRACE << "connectionClosed";
|
||||||
// loop_->assertInLoopThread();
|
// loop_->assertInLoopThread();
|
||||||
@ -179,11 +185,7 @@ const InetAddress &TcpServer::address() const {
|
|||||||
return acceptorPtr_->addr();
|
return acceptorPtr_->addr();
|
||||||
}
|
}
|
||||||
|
|
||||||
void TcpServer::enableSSL(
|
void TcpServer::enableSSL(const std::string &certPath, const std::string &keyPath, bool useOldTLS, const std::vector<std::pair<std::string, std::string> > &sslConfCmds) {
|
||||||
const std::string &certPath,
|
|
||||||
const std::string &keyPath,
|
|
||||||
bool useOldTLS,
|
|
||||||
const std::vector<std::pair<std::string, std::string> > &sslConfCmds) {
|
|
||||||
#ifdef USE_OPENSSL
|
#ifdef USE_OPENSSL
|
||||||
/* Create a new OpenSSL context */
|
/* Create a new OpenSSL context */
|
||||||
sslCtxPtr_ = newSSLServerContext(certPath, keyPath, useOldTLS, sslConfCmds);
|
sslCtxPtr_ = newSSLServerContext(certPath, keyPath, useOldTLS, sslConfCmds);
|
||||||
|
@ -43,10 +43,7 @@
|
|||||||
|
|
||||||
class Acceptor;
|
class Acceptor;
|
||||||
class SSLContext;
|
class SSLContext;
|
||||||
/**
|
|
||||||
* @brief This class represents a TCP server.
|
|
||||||
*
|
|
||||||
*/
|
|
||||||
class TcpServer {
|
class TcpServer {
|
||||||
protected:
|
protected:
|
||||||
TcpServer(const TcpServer &) = delete;
|
TcpServer(const TcpServer &) = delete;
|
||||||
@ -56,53 +53,18 @@ protected:
|
|||||||
TcpServer &operator=(TcpServer &&) noexcept(true) = default;
|
TcpServer &operator=(TcpServer &&) noexcept(true) = default;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
/**
|
TcpServer(EventLoop *loop, const InetAddress &address, const std::string &name, bool reUseAddr = true, bool reUsePort = true);
|
||||||
* @brief Construct a new TCP server instance.
|
|
||||||
*
|
|
||||||
* @param loop The event loop in which the acceptor of the server is
|
|
||||||
* handled.
|
|
||||||
* @param address The address of the server.
|
|
||||||
* @param name The name of the server.
|
|
||||||
* @param reUseAddr The SO_REUSEADDR option.
|
|
||||||
* @param reUsePort The SO_REUSEPORT option.
|
|
||||||
*/
|
|
||||||
TcpServer(EventLoop *loop,
|
|
||||||
const InetAddress &address,
|
|
||||||
const std::string &name,
|
|
||||||
bool reUseAddr = true,
|
|
||||||
bool reUsePort = true);
|
|
||||||
~TcpServer();
|
~TcpServer();
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief Start the server.
|
|
||||||
*
|
|
||||||
*/
|
|
||||||
void start();
|
void start();
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief Stop the server.
|
|
||||||
*
|
|
||||||
*/
|
|
||||||
void stop();
|
void stop();
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief Set the number of event loops in which the I/O of connections to
|
|
||||||
* the server is handled.
|
|
||||||
*
|
|
||||||
* @param num
|
|
||||||
*/
|
|
||||||
void setIoLoopNum(size_t num) {
|
void setIoLoopNum(size_t num) {
|
||||||
assert(!started_);
|
assert(!started_);
|
||||||
loopPoolPtr_ = std::make_shared<EventLoopThreadPool>(num);
|
loopPoolPtr_ = std::make_shared<EventLoopThreadPool>(num);
|
||||||
loopPoolPtr_->start();
|
loopPoolPtr_->start();
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief Set the event loops pool in which the I/O of connections to
|
|
||||||
* the server is handled.
|
|
||||||
*
|
|
||||||
* @param pool
|
|
||||||
*/
|
|
||||||
void setIoLoopThreadPool(const std::shared_ptr<EventLoopThreadPool> &pool) {
|
void setIoLoopThreadPool(const std::shared_ptr<EventLoopThreadPool> &pool) {
|
||||||
assert(pool->size() > 0);
|
assert(pool->size() > 0);
|
||||||
assert(!started_);
|
assert(!started_);
|
||||||
@ -110,12 +72,6 @@ public:
|
|||||||
loopPoolPtr_->start();
|
loopPoolPtr_->start();
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief Set the message callback.
|
|
||||||
*
|
|
||||||
* @param cb The callback is called when some data is received on a
|
|
||||||
* connection to the server.
|
|
||||||
*/
|
|
||||||
void setRecvMessageCallback(const RecvMessageCallback &cb) {
|
void setRecvMessageCallback(const RecvMessageCallback &cb) {
|
||||||
recvMessageCallback_ = cb;
|
recvMessageCallback_ = cb;
|
||||||
}
|
}
|
||||||
@ -123,12 +79,6 @@ public:
|
|||||||
recvMessageCallback_ = std::move(cb);
|
recvMessageCallback_ = std::move(cb);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief Set the connection callback.
|
|
||||||
*
|
|
||||||
* @param cb The callback is called when a connection is established or
|
|
||||||
* closed.
|
|
||||||
*/
|
|
||||||
void setConnectionCallback(const ConnectionCallback &cb) {
|
void setConnectionCallback(const ConnectionCallback &cb) {
|
||||||
connectionCallback_ = cb;
|
connectionCallback_ = cb;
|
||||||
}
|
}
|
||||||
@ -136,12 +86,6 @@ public:
|
|||||||
connectionCallback_ = std::move(cb);
|
connectionCallback_ = std::move(cb);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief Set the write complete callback.
|
|
||||||
*
|
|
||||||
* @param cb The callback is called when data to send is written to the
|
|
||||||
* socket of a connection.
|
|
||||||
*/
|
|
||||||
void setWriteCompleteCallback(const WriteCompleteCallback &cb) {
|
void setWriteCompleteCallback(const WriteCompleteCallback &cb) {
|
||||||
writeCompleteCallback_ = cb;
|
writeCompleteCallback_ = cb;
|
||||||
}
|
}
|
||||||
@ -149,53 +93,21 @@ public:
|
|||||||
writeCompleteCallback_ = std::move(cb);
|
writeCompleteCallback_ = std::move(cb);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief Get the name of the server.
|
|
||||||
*
|
|
||||||
* @return const std::string&
|
|
||||||
*/
|
|
||||||
const std::string &name() const {
|
const std::string &name() const {
|
||||||
return serverName_;
|
return serverName_;
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief Get the IP and port string of the server.
|
|
||||||
*
|
|
||||||
* @return const std::string
|
|
||||||
*/
|
|
||||||
const std::string ipPort() const;
|
const std::string ipPort() const;
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief Get the address of the server.
|
|
||||||
*
|
|
||||||
* @return const InetAddress&
|
|
||||||
*/
|
|
||||||
const InetAddress &address() const;
|
const InetAddress &address() const;
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief Get the event loop of the server.
|
|
||||||
*
|
|
||||||
* @return EventLoop*
|
|
||||||
*/
|
|
||||||
EventLoop *getLoop() const {
|
EventLoop *getLoop() const {
|
||||||
return loop_;
|
return loop_;
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief Get the I/O event loops of the server.
|
|
||||||
*
|
|
||||||
* @return std::vector<EventLoop *>
|
|
||||||
*/
|
|
||||||
std::vector<EventLoop *> getIoLoops() const {
|
std::vector<EventLoop *> getIoLoops() const {
|
||||||
return loopPoolPtr_->getLoops();
|
return loopPoolPtr_->getLoops();
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief An idle connection is a connection that has no read or write, kick
|
|
||||||
* off it after timeout seconds.
|
|
||||||
*
|
|
||||||
* @param timeout
|
|
||||||
*/
|
|
||||||
void kickoffIdleConnections(size_t timeout) {
|
void kickoffIdleConnections(size_t timeout) {
|
||||||
loop_->runInLoop([this, timeout]() {
|
loop_->runInLoop([this, timeout]() {
|
||||||
assert(!started_);
|
assert(!started_);
|
||||||
@ -203,23 +115,12 @@ public:
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
// certPath The path of the certificate file.
|
||||||
* @brief Enable SSL encryption.
|
// keyPath The path of the private key file.
|
||||||
*
|
// useOldTLS If true, the TLS 1.0 and 1.1 are supported by the server.
|
||||||
* @param certPath The path of the certificate file.
|
// sslConfCmds The commands used to call the SSL_CONF_cmd function in OpenSSL.
|
||||||
* @param keyPath The path of the private key file.
|
// Note: It's well known that TLS 1.0 and 1.1 are not considered secure in 2020. And it's a good practice to only use TLS 1.2 and above.
|
||||||
* @param useOldTLS If true, the TLS 1.0 and 1.1 are supported by the
|
void enableSSL(const std::string &certPath, const std::string &keyPath, bool useOldTLS = false, const std::vector<std::pair<std::string, std::string> > &sslConfCmds = {});
|
||||||
* server.
|
|
||||||
* @param sslConfCmds The commands used to call the SSL_CONF_cmd function in
|
|
||||||
* OpenSSL.
|
|
||||||
* @note It's well known that TLS 1.0 and 1.1 are not considered secure in
|
|
||||||
* 2020. And it's a good practice to only use TLS 1.2 and above.
|
|
||||||
*/
|
|
||||||
void enableSSL(const std::string &certPath,
|
|
||||||
const std::string &keyPath,
|
|
||||||
bool useOldTLS = false,
|
|
||||||
const std::vector<std::pair<std::string, std::string> >
|
|
||||||
&sslConfCmds = {});
|
|
||||||
|
|
||||||
private:
|
private:
|
||||||
EventLoop *loop_;
|
EventLoop *loop_;
|
||||||
@ -236,6 +137,7 @@ private:
|
|||||||
std::map<EventLoop *, std::shared_ptr<TimingWheel> > timingWheelMap_;
|
std::map<EventLoop *, std::shared_ptr<TimingWheel> > timingWheelMap_;
|
||||||
void connectionClosed(const TcpConnectionPtr &connectionPtr);
|
void connectionClosed(const TcpConnectionPtr &connectionPtr);
|
||||||
std::shared_ptr<EventLoopThreadPool> loopPoolPtr_;
|
std::shared_ptr<EventLoopThreadPool> loopPoolPtr_;
|
||||||
|
|
||||||
#ifndef _WIN32
|
#ifndef _WIN32
|
||||||
class IgnoreSigPipe {
|
class IgnoreSigPipe {
|
||||||
public:
|
public:
|
||||||
@ -247,6 +149,7 @@ private:
|
|||||||
|
|
||||||
IgnoreSigPipe initObj;
|
IgnoreSigPipe initObj;
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
bool started_{ false };
|
bool started_{ false };
|
||||||
|
|
||||||
// OpenSSL SSL context Object;
|
// OpenSSL SSL context Object;
|
||||||
|
Loading…
Reference in New Issue
Block a user