// Mirror of https://github.com/Relintai/rcpp_framework.git (synced 2024-11-14 04:57:21 +01:00); 345 lines, 9.6 KiB, C.
|
/**
 *
 *  @file TcpConnectionImpl.h
 *  @author An Tao
 *
 *  Public header file in trantor lib.
 *
 *  Copyright 2018, An Tao.  All rights reserved.
 *  Use of this source code is governed by a BSD-style license
 *  that can be found in the License file.
 *
 *
 */
|
||
|
|
||
|
#pragma once
|
||
|
|
||
|
#include <trantor/net/TcpConnection.h>
#include <trantor/utils/TimingWheel.h>

#include <array>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <functional>
#include <list>
#include <memory>
#include <mutex>
#include <string>
#include <thread>
#include <utility>
#include <vector>
#ifndef _WIN32
#include <unistd.h>
#endif
|
||
|
|
||
|
namespace trantor
|
||
|
{
|
||
|
#ifdef USE_OPENSSL
|
||
|
/**
 * @brief State of the TLS layer of a connection.
 *
 * Only compiled when USE_OPENSSL is defined. The enumerator order (and
 * therefore the underlying values 0..4) is part of the existing definition
 * and must not be changed.
 *
 * NOTE(review): the exact transition rules between these states live in the
 * implementation file and are not visible from this header.
 */
enum class SSLStatus
{
    Handshaking,  ///< Initial value of SSLEncryption::statusOfSSL_.
    Connecting,
    Connected,
    DisConnecting,
    DisConnected
};
|
||
|
// Opaque TLS wrapper types. They are only forward-declared in this header,
// so they can be used here solely through pointers / smart pointers.
class SSLContext;
class SSLConn;

/**
 * @brief Declares the factory for a shared SSLContext.
 *
 * @param useOldTLS     NOTE(review): presumably allows legacy TLS protocol
 *                      versions - confirm against the implementation.
 * @param validateCert  NOTE(review): presumably enables peer certificate
 *                      validation - confirm against the implementation.
 * @param sslConfCmds   List of string pairs; presumably (command, value)
 *                      SSL configuration commands - confirm.
 * @return A shared pointer to the created context.
 */
std::shared_ptr<SSLContext> newSSLContext(
    bool useOldTLS,
    bool validateCert,
    const std::vector<std::pair<std::string, std::string>> &sslConfCmds);

/**
 * @brief Declares the factory for a shared server-side SSLContext.
 *
 * @param certPath     Path of the certificate (format not visible here).
 * @param keyPath      Path of the private key (format not visible here).
 * @param useOldTLS    NOTE(review): presumably allows legacy TLS protocol
 *                     versions - confirm against the implementation.
 * @param sslConfCmds  List of string pairs; presumably (command, value)
 *                     SSL configuration commands - confirm.
 * @return A shared pointer to the created context.
 */
std::shared_ptr<SSLContext> newSSLServerContext(
    const std::string &certPath,
    const std::string &keyPath,
    bool useOldTLS,
    const std::vector<std::pair<std::string, std::string>> &sslConfCmds);
|
||
|
// void initServerSSLContext(const std::shared_ptr<SSLContext> &ctx,
//                           const std::string &certPath,
//                           const std::string &keyPath);
|
||
|
#endif
|
||
|
class Channel;
|
||
|
class Socket;
|
||
|
class TcpServer;
|
||
|
void removeConnection(EventLoop *loop, const TcpConnectionPtr &conn);
|
||
|
class TcpConnectionImpl : public TcpConnection,
|
||
|
public NonCopyable,
|
||
|
public std::enable_shared_from_this<TcpConnectionImpl>
|
||
|
{
|
||
|
friend class TcpServer;
|
||
|
friend class TcpClient;
|
||
|
friend void trantor::removeConnection(EventLoop *loop,
|
||
|
const TcpConnectionPtr &conn);
|
||
|
|
||
|
public:
|
||
|
class KickoffEntry
|
||
|
{
|
||
|
public:
|
||
|
explicit KickoffEntry(const std::weak_ptr<TcpConnection> &conn)
|
||
|
: conn_(conn)
|
||
|
{
|
||
|
}
|
||
|
void reset()
|
||
|
{
|
||
|
conn_.reset();
|
||
|
}
|
||
|
~KickoffEntry()
|
||
|
{
|
||
|
auto conn = conn_.lock();
|
||
|
if (conn)
|
||
|
{
|
||
|
conn->forceClose();
|
||
|
}
|
||
|
}
|
||
|
|
||
|
private:
|
||
|
std::weak_ptr<TcpConnection> conn_;
|
||
|
};
|
||
|
|
||
|
TcpConnectionImpl(EventLoop *loop,
|
||
|
int socketfd,
|
||
|
const InetAddress &localAddr,
|
||
|
const InetAddress &peerAddr);
|
||
|
#ifdef USE_OPENSSL
|
||
|
TcpConnectionImpl(EventLoop *loop,
|
||
|
int socketfd,
|
||
|
const InetAddress &localAddr,
|
||
|
const InetAddress &peerAddr,
|
||
|
const std::shared_ptr<SSLContext> &ctxPtr,
|
||
|
bool isServer = true,
|
||
|
bool validateCert = true,
|
||
|
const std::string &hostname = "");
|
||
|
#endif
|
||
|
virtual ~TcpConnectionImpl();
|
||
|
virtual void send(const char *msg, size_t len) override;
|
||
|
virtual void send(const void *msg, size_t len) override;
|
||
|
virtual void send(const std::string &msg) override;
|
||
|
virtual void send(std::string &&msg) override;
|
||
|
virtual void send(const MsgBuffer &buffer) override;
|
||
|
virtual void send(MsgBuffer &&buffer) override;
|
||
|
virtual void send(const std::shared_ptr<std::string> &msgPtr) override;
|
||
|
virtual void send(const std::shared_ptr<MsgBuffer> &msgPtr) override;
|
||
|
virtual void sendFile(const char *fileName,
|
||
|
size_t offset = 0,
|
||
|
size_t length = 0) override;
|
||
|
|
||
|
virtual const InetAddress &localAddr() const override
|
||
|
{
|
||
|
return localAddr_;
|
||
|
}
|
||
|
virtual const InetAddress &peerAddr() const override
|
||
|
{
|
||
|
return peerAddr_;
|
||
|
}
|
||
|
|
||
|
virtual bool connected() const override
|
||
|
{
|
||
|
return status_ == ConnStatus::Connected;
|
||
|
}
|
||
|
virtual bool disconnected() const override
|
||
|
{
|
||
|
return status_ == ConnStatus::Disconnected;
|
||
|
}
|
||
|
|
||
|
// virtual MsgBuffer* getSendBuffer() override{ return &writeBuffer_;}
|
||
|
virtual MsgBuffer *getRecvBuffer() override
|
||
|
{
|
||
|
return &readBuffer_;
|
||
|
}
|
||
|
// set callbacks
|
||
|
virtual void setHighWaterMarkCallback(const HighWaterMarkCallback &cb,
|
||
|
size_t markLen) override
|
||
|
{
|
||
|
highWaterMarkCallback_ = cb;
|
||
|
highWaterMarkLen_ = markLen;
|
||
|
}
|
||
|
|
||
|
virtual void keepAlive() override
|
||
|
{
|
||
|
idleTimeout_ = 0;
|
||
|
auto entry = kickoffEntry_.lock();
|
||
|
if (entry)
|
||
|
{
|
||
|
entry->reset();
|
||
|
}
|
||
|
}
|
||
|
virtual bool isKeepAlive() override
|
||
|
{
|
||
|
return idleTimeout_ == 0;
|
||
|
}
|
||
|
virtual void setTcpNoDelay(bool on) override;
|
||
|
virtual void shutdown() override;
|
||
|
virtual void forceClose() override;
|
||
|
virtual EventLoop *getLoop() override
|
||
|
{
|
||
|
return loop_;
|
||
|
}
|
||
|
|
||
|
virtual size_t bytesSent() const override
|
||
|
{
|
||
|
return bytesSent_;
|
||
|
}
|
||
|
virtual size_t bytesReceived() const override
|
||
|
{
|
||
|
return bytesReceived_;
|
||
|
}
|
||
|
virtual void startClientEncryption(
|
||
|
std::function<void()> callback,
|
||
|
bool useOldTLS = false,
|
||
|
bool validateCert = true,
|
||
|
std::string hostname = "",
|
||
|
const std::vector<std::pair<std::string, std::string>> &sslConfCmds =
|
||
|
{}) override;
|
||
|
virtual void startServerEncryption(const std::shared_ptr<SSLContext> &ctx,
|
||
|
std::function<void()> callback) override;
|
||
|
virtual bool isSSLConnection() const override
|
||
|
{
|
||
|
return isEncrypted_;
|
||
|
}
|
||
|
|
||
|
private:
|
||
|
/// Internal use only.
|
||
|
|
||
|
std::weak_ptr<KickoffEntry> kickoffEntry_;
|
||
|
std::weak_ptr<TimingWheel> timingWheelWeakPtr_;
|
||
|
size_t idleTimeout_{0};
|
||
|
Date lastTimingWheelUpdateTime_;
|
||
|
|
||
|
void enableKickingOff(size_t timeout,
|
||
|
const std::shared_ptr<TimingWheel> &timingWheel)
|
||
|
{
|
||
|
assert(timingWheel);
|
||
|
assert(timingWheel->getLoop() == loop_);
|
||
|
assert(timeout > 0);
|
||
|
auto entry = std::make_shared<KickoffEntry>(shared_from_this());
|
||
|
kickoffEntry_ = entry;
|
||
|
timingWheelWeakPtr_ = timingWheel;
|
||
|
idleTimeout_ = timeout;
|
||
|
timingWheel->insertEntry(timeout, entry);
|
||
|
}
|
||
|
void extendLife();
|
||
|
#ifndef _WIN32
|
||
|
void sendFile(int sfd, size_t offset = 0, size_t length = 0);
|
||
|
#else
|
||
|
void sendFile(FILE *fp, size_t offset = 0, size_t length = 0);
|
||
|
#endif
|
||
|
void setRecvMsgCallback(const RecvMessageCallback &cb)
|
||
|
{
|
||
|
recvMsgCallback_ = cb;
|
||
|
}
|
||
|
void setConnectionCallback(const ConnectionCallback &cb)
|
||
|
{
|
||
|
connectionCallback_ = cb;
|
||
|
}
|
||
|
void setWriteCompleteCallback(const WriteCompleteCallback &cb)
|
||
|
{
|
||
|
writeCompleteCallback_ = cb;
|
||
|
}
|
||
|
void setCloseCallback(const CloseCallback &cb)
|
||
|
{
|
||
|
closeCallback_ = cb;
|
||
|
}
|
||
|
void setSSLErrorCallback(const SSLErrorCallback &cb)
|
||
|
{
|
||
|
sslErrorCallback_ = cb;
|
||
|
}
|
||
|
|
||
|
void connectDestroyed();
|
||
|
virtual void connectEstablished();
|
||
|
|
||
|
protected:
|
||
|
struct BufferNode
|
||
|
{
|
||
|
#ifndef _WIN32
|
||
|
int sendFd_{-1};
|
||
|
off_t offset_;
|
||
|
#else
|
||
|
FILE *sendFp_{nullptr};
|
||
|
long long offset_;
|
||
|
#endif
|
||
|
ssize_t fileBytesToSend_;
|
||
|
std::shared_ptr<MsgBuffer> msgBuffer_;
|
||
|
~BufferNode()
|
||
|
{
|
||
|
#ifndef _WIN32
|
||
|
if (sendFd_ >= 0)
|
||
|
close(sendFd_);
|
||
|
#else
|
||
|
if (sendFp_)
|
||
|
fclose(sendFp_);
|
||
|
#endif
|
||
|
}
|
||
|
};
|
||
|
using BufferNodePtr = std::shared_ptr<BufferNode>;
|
||
|
enum class ConnStatus
|
||
|
{
|
||
|
Disconnected,
|
||
|
Connecting,
|
||
|
Connected,
|
||
|
Disconnecting
|
||
|
};
|
||
|
bool isEncrypted_{false};
|
||
|
EventLoop *loop_;
|
||
|
std::unique_ptr<Channel> ioChannelPtr_;
|
||
|
std::unique_ptr<Socket> socketPtr_;
|
||
|
MsgBuffer readBuffer_;
|
||
|
std::list<BufferNodePtr> writeBufferList_;
|
||
|
void readCallback();
|
||
|
void writeCallback();
|
||
|
InetAddress localAddr_, peerAddr_;
|
||
|
ConnStatus status_{ConnStatus::Connecting};
|
||
|
// callbacks
|
||
|
RecvMessageCallback recvMsgCallback_;
|
||
|
ConnectionCallback connectionCallback_;
|
||
|
CloseCallback closeCallback_;
|
||
|
WriteCompleteCallback writeCompleteCallback_;
|
||
|
HighWaterMarkCallback highWaterMarkCallback_;
|
||
|
SSLErrorCallback sslErrorCallback_;
|
||
|
void handleClose();
|
||
|
void handleError();
|
||
|
// virtual void sendInLoop(const std::string &msg);
|
||
|
|
||
|
void sendFileInLoop(const BufferNodePtr &file);
|
||
|
#ifndef _WIN32
|
||
|
void sendInLoop(const void *buffer, size_t length);
|
||
|
ssize_t writeInLoop(const void *buffer, size_t length);
|
||
|
#else
|
||
|
void sendInLoop(const char *buffer, size_t length);
|
||
|
ssize_t writeInLoop(const char *buffer, size_t length);
|
||
|
#endif
|
||
|
size_t highWaterMarkLen_;
|
||
|
std::string name_;
|
||
|
|
||
|
uint64_t sendNum_{0};
|
||
|
std::mutex sendNumMutex_;
|
||
|
|
||
|
size_t bytesSent_{0};
|
||
|
size_t bytesReceived_{0};
|
||
|
|
||
|
std::unique_ptr<std::vector<char>> fileBufferPtr_;
|
||
|
|
||
|
#ifdef USE_OPENSSL
|
||
|
private:
|
||
|
void doHandshaking();
|
||
|
bool validatePeerCertificate();
|
||
|
struct SSLEncryption
|
||
|
{
|
||
|
SSLStatus statusOfSSL_ = SSLStatus::Handshaking;
|
||
|
// OpenSSL
|
||
|
std::shared_ptr<SSLContext> sslCtxPtr_;
|
||
|
std::unique_ptr<SSLConn> sslPtr_;
|
||
|
std::unique_ptr<std::array<char, 8192>> sendBufferPtr_;
|
||
|
bool isServer_{false};
|
||
|
bool isUpgrade_{false};
|
||
|
std::function<void()> upgradeCallback_;
|
||
|
std::string hostname_;
|
||
|
};
|
||
|
std::unique_ptr<SSLEncryption> sslEncryptionPtr_;
|
||
|
void startClientEncryptionInLoop(
|
||
|
std::function<void()> &&callback,
|
||
|
bool useOldTLS,
|
||
|
bool validateCert,
|
||
|
const std::string &hostname,
|
||
|
const std::vector<std::pair<std::string, std::string>> &sslConfCmds);
|
||
|
void startServerEncryptionInLoop(const std::shared_ptr<SSLContext> &ctx,
|
||
|
std::function<void()> &&callback);
|
||
|
#endif
|
||
|
};
|
||
|
|
||
|
/// Shared-ownership handle to a TcpConnectionImpl.
using TcpConnectionImplPtr = std::shared_ptr<TcpConnectionImpl>;
|
||
|
|
||
|
} // namespace trantor
|