// rcpp_framework/modules/drogon/trantor/net/inner/TcpConnectionImpl.h
// (345 lines, 9.6 KiB, C++)

/**
*
* @file TcpConnectionImpl.h
* @author An Tao
*
* Public header file in trantor lib.
*
* Copyright 2018, An Tao. All rights reserved.
* Use of this source code is governed by a BSD-style license
* that can be found in the License file.
*
*
*/
#pragma once
#include <trantor/net/TcpConnection.h>
#include <trantor/utils/TimingWheel.h>
#include <list>
#include <mutex>
#ifndef _WIN32
#include <unistd.h>
#endif
#include <thread>
#include <array>
namespace trantor
{
#ifdef USE_OPENSSL
/// Progress marker for the TLS layer of a connection.
/// `Handshaking` is the value a freshly created
/// `TcpConnectionImpl::SSLEncryption` starts with.
/// The numeric values are written out explicitly; they are the same values
/// the enumerators receive implicitly from their declaration order.
enum class SSLStatus
{
    Handshaking = 0,
    Connecting = 1,
    Connected = 2,
    DisConnecting = 3,
    DisConnected = 4
};
class SSLContext;
class SSLConn;
/// Returns a shared SSLContext. Declaration only; the definition is not in
/// this header.
/// NOTE(review): the parameter descriptions below are inferred from the
/// parameter names and must be confirmed against the implementation.
/// @param useOldTLS presumably allows legacy TLS protocol versions.
/// @param validateCert presumably turns on peer-certificate validation.
/// @param sslConfCmds list of string pairs; presumably SSL configuration
/// command/value pairs applied to the new context.
std::shared_ptr<SSLContext> newSSLContext(
bool useOldTLS,
bool validateCert,
const std::vector<std::pair<std::string, std::string>> &sslConfCmds);
/// Returns a shared SSLContext for the server side. Declaration only; the
/// definition is not in this header.
/// NOTE(review): the parameter descriptions below are inferred from the
/// parameter names and must be confirmed against the implementation.
/// @param certPath presumably the path of the server certificate file.
/// @param keyPath presumably the path of the matching private-key file.
/// @param useOldTLS presumably allows legacy TLS protocol versions.
/// @param sslConfCmds list of string pairs; presumably SSL configuration
/// command/value pairs applied to the new context.
std::shared_ptr<SSLContext> newSSLServerContext(
const std::string &certPath,
const std::string &keyPath,
bool useOldTLS,
const std::vector<std::pair<std::string, std::string>> &sslConfCmds);
// void initServerSSLContext(const std::shared_ptr<SSLContext> &ctx,
// const std::string &certPath,
// const std::string &keyPath);
#endif
class Channel;
class Socket;
class TcpServer;
/// Free function declared here so that TcpConnectionImpl can name it as a
/// friend; the definition is not in this header.
/// NOTE(review): presumably detaches `conn` from its owner on `loop` —
/// confirm against the implementation.
void removeConnection(EventLoop *loop, const TcpConnectionPtr &conn);
class TcpConnectionImpl : public TcpConnection,
public NonCopyable,
public std::enable_shared_from_this<TcpConnectionImpl>
{
friend class TcpServer;
friend class TcpClient;
friend void trantor::removeConnection(EventLoop *loop,
const TcpConnectionPtr &conn);
public:
class KickoffEntry
{
public:
explicit KickoffEntry(const std::weak_ptr<TcpConnection> &conn)
: conn_(conn)
{
}
void reset()
{
conn_.reset();
}
~KickoffEntry()
{
auto conn = conn_.lock();
if (conn)
{
conn->forceClose();
}
}
private:
std::weak_ptr<TcpConnection> conn_;
};
TcpConnectionImpl(EventLoop *loop,
int socketfd,
const InetAddress &localAddr,
const InetAddress &peerAddr);
#ifdef USE_OPENSSL
TcpConnectionImpl(EventLoop *loop,
int socketfd,
const InetAddress &localAddr,
const InetAddress &peerAddr,
const std::shared_ptr<SSLContext> &ctxPtr,
bool isServer = true,
bool validateCert = true,
const std::string &hostname = "");
#endif
virtual ~TcpConnectionImpl();
virtual void send(const char *msg, size_t len) override;
virtual void send(const void *msg, size_t len) override;
virtual void send(const std::string &msg) override;
virtual void send(std::string &&msg) override;
virtual void send(const MsgBuffer &buffer) override;
virtual void send(MsgBuffer &&buffer) override;
virtual void send(const std::shared_ptr<std::string> &msgPtr) override;
virtual void send(const std::shared_ptr<MsgBuffer> &msgPtr) override;
virtual void sendFile(const char *fileName,
size_t offset = 0,
size_t length = 0) override;
virtual const InetAddress &localAddr() const override
{
return localAddr_;
}
virtual const InetAddress &peerAddr() const override
{
return peerAddr_;
}
virtual bool connected() const override
{
return status_ == ConnStatus::Connected;
}
virtual bool disconnected() const override
{
return status_ == ConnStatus::Disconnected;
}
// virtual MsgBuffer* getSendBuffer() override{ return &writeBuffer_;}
virtual MsgBuffer *getRecvBuffer() override
{
return &readBuffer_;
}
// set callbacks
virtual void setHighWaterMarkCallback(const HighWaterMarkCallback &cb,
size_t markLen) override
{
highWaterMarkCallback_ = cb;
highWaterMarkLen_ = markLen;
}
virtual void keepAlive() override
{
idleTimeout_ = 0;
auto entry = kickoffEntry_.lock();
if (entry)
{
entry->reset();
}
}
virtual bool isKeepAlive() override
{
return idleTimeout_ == 0;
}
virtual void setTcpNoDelay(bool on) override;
virtual void shutdown() override;
virtual void forceClose() override;
virtual EventLoop *getLoop() override
{
return loop_;
}
virtual size_t bytesSent() const override
{
return bytesSent_;
}
virtual size_t bytesReceived() const override
{
return bytesReceived_;
}
virtual void startClientEncryption(
std::function<void()> callback,
bool useOldTLS = false,
bool validateCert = true,
std::string hostname = "",
const std::vector<std::pair<std::string, std::string>> &sslConfCmds =
{}) override;
virtual void startServerEncryption(const std::shared_ptr<SSLContext> &ctx,
std::function<void()> callback) override;
virtual bool isSSLConnection() const override
{
return isEncrypted_;
}
private:
/// Internal use only.
std::weak_ptr<KickoffEntry> kickoffEntry_;
std::weak_ptr<TimingWheel> timingWheelWeakPtr_;
size_t idleTimeout_{0};
Date lastTimingWheelUpdateTime_;
void enableKickingOff(size_t timeout,
const std::shared_ptr<TimingWheel> &timingWheel)
{
assert(timingWheel);
assert(timingWheel->getLoop() == loop_);
assert(timeout > 0);
auto entry = std::make_shared<KickoffEntry>(shared_from_this());
kickoffEntry_ = entry;
timingWheelWeakPtr_ = timingWheel;
idleTimeout_ = timeout;
timingWheel->insertEntry(timeout, entry);
}
void extendLife();
#ifndef _WIN32
void sendFile(int sfd, size_t offset = 0, size_t length = 0);
#else
void sendFile(FILE *fp, size_t offset = 0, size_t length = 0);
#endif
void setRecvMsgCallback(const RecvMessageCallback &cb)
{
recvMsgCallback_ = cb;
}
void setConnectionCallback(const ConnectionCallback &cb)
{
connectionCallback_ = cb;
}
void setWriteCompleteCallback(const WriteCompleteCallback &cb)
{
writeCompleteCallback_ = cb;
}
void setCloseCallback(const CloseCallback &cb)
{
closeCallback_ = cb;
}
void setSSLErrorCallback(const SSLErrorCallback &cb)
{
sslErrorCallback_ = cb;
}
void connectDestroyed();
virtual void connectEstablished();
protected:
struct BufferNode
{
#ifndef _WIN32
int sendFd_{-1};
off_t offset_;
#else
FILE *sendFp_{nullptr};
long long offset_;
#endif
ssize_t fileBytesToSend_;
std::shared_ptr<MsgBuffer> msgBuffer_;
~BufferNode()
{
#ifndef _WIN32
if (sendFd_ >= 0)
close(sendFd_);
#else
if (sendFp_)
fclose(sendFp_);
#endif
}
};
using BufferNodePtr = std::shared_ptr<BufferNode>;
enum class ConnStatus
{
Disconnected,
Connecting,
Connected,
Disconnecting
};
bool isEncrypted_{false};
EventLoop *loop_;
std::unique_ptr<Channel> ioChannelPtr_;
std::unique_ptr<Socket> socketPtr_;
MsgBuffer readBuffer_;
std::list<BufferNodePtr> writeBufferList_;
void readCallback();
void writeCallback();
InetAddress localAddr_, peerAddr_;
ConnStatus status_{ConnStatus::Connecting};
// callbacks
RecvMessageCallback recvMsgCallback_;
ConnectionCallback connectionCallback_;
CloseCallback closeCallback_;
WriteCompleteCallback writeCompleteCallback_;
HighWaterMarkCallback highWaterMarkCallback_;
SSLErrorCallback sslErrorCallback_;
void handleClose();
void handleError();
// virtual void sendInLoop(const std::string &msg);
void sendFileInLoop(const BufferNodePtr &file);
#ifndef _WIN32
void sendInLoop(const void *buffer, size_t length);
ssize_t writeInLoop(const void *buffer, size_t length);
#else
void sendInLoop(const char *buffer, size_t length);
ssize_t writeInLoop(const char *buffer, size_t length);
#endif
size_t highWaterMarkLen_;
std::string name_;
uint64_t sendNum_{0};
std::mutex sendNumMutex_;
size_t bytesSent_{0};
size_t bytesReceived_{0};
std::unique_ptr<std::vector<char>> fileBufferPtr_;
#ifdef USE_OPENSSL
private:
void doHandshaking();
bool validatePeerCertificate();
struct SSLEncryption
{
SSLStatus statusOfSSL_ = SSLStatus::Handshaking;
// OpenSSL
std::shared_ptr<SSLContext> sslCtxPtr_;
std::unique_ptr<SSLConn> sslPtr_;
std::unique_ptr<std::array<char, 8192>> sendBufferPtr_;
bool isServer_{false};
bool isUpgrade_{false};
std::function<void()> upgradeCallback_;
std::string hostname_;
};
std::unique_ptr<SSLEncryption> sslEncryptionPtr_;
void startClientEncryptionInLoop(
std::function<void()> &&callback,
bool useOldTLS,
bool validateCert,
const std::string &hostname,
const std::vector<std::pair<std::string, std::string>> &sslConfCmds);
void startServerEncryptionInLoop(const std::shared_ptr<SSLContext> &ctx,
std::function<void()> &&callback);
#endif
};
/// Shared-ownership handle to a TcpConnectionImpl.
using TcpConnectionImplPtr = std::shared_ptr<TcpConnectionImpl>;
} // namespace trantor