/** * * @file TcpConnectionImpl.h * @author An Tao * * Public header file in trantor lib. * * Copyright 2018, An Tao. All rights reserved. * Use of this source code is governed by a BSD-style license * that can be found in the License file. * * */ #pragma once #include <trantor/net/TcpConnection.h> #include <trantor/utils/TimingWheel.h> #include <list> #include <mutex> #ifndef _WIN32 #include <unistd.h> #endif #include <thread> #include <array> namespace trantor { #ifdef USE_OPENSSL enum class SSLStatus { Handshaking, Connecting, Connected, DisConnecting, DisConnected }; class SSLContext; class SSLConn; std::shared_ptr<SSLContext> newSSLContext( bool useOldTLS, bool validateCert, const std::vector<std::pair<std::string, std::string>> &sslConfCmds); std::shared_ptr<SSLContext> newSSLServerContext( const std::string &certPath, const std::string &keyPath, bool useOldTLS, const std::vector<std::pair<std::string, std::string>> &sslConfCmds); // void initServerSSLContext(const std::shared_ptr<SSLContext> &ctx, // const std::string &certPath, // const std::string &keyPath); #endif class Channel; class Socket; class TcpServer; void removeConnection(EventLoop *loop, const TcpConnectionPtr &conn); class TcpConnectionImpl : public TcpConnection, public NonCopyable, public std::enable_shared_from_this<TcpConnectionImpl> { friend class TcpServer; friend class TcpClient; friend void trantor::removeConnection(EventLoop *loop, const TcpConnectionPtr &conn); public: class KickoffEntry { public: explicit KickoffEntry(const std::weak_ptr<TcpConnection> &conn) : conn_(conn) { } void reset() { conn_.reset(); } ~KickoffEntry() { auto conn = conn_.lock(); if (conn) { conn->forceClose(); } } private: std::weak_ptr<TcpConnection> conn_; }; TcpConnectionImpl(EventLoop *loop, int socketfd, const InetAddress &localAddr, const InetAddress &peerAddr); #ifdef USE_OPENSSL TcpConnectionImpl(EventLoop *loop, int socketfd, const InetAddress &localAddr, const InetAddress &peerAddr, const std::shared_ptr<SSLContext> &ctxPtr, bool isServer = true, bool validateCert = true, const std::string &hostname = ""); #endif virtual ~TcpConnectionImpl(); virtual void send(const char *msg, size_t len) override; virtual void send(const void *msg, size_t len) override; virtual void send(const std::string &msg) override; virtual void send(std::string &&msg) override; virtual void send(const MsgBuffer &buffer) override; virtual void send(MsgBuffer &&buffer) override; virtual void send(const std::shared_ptr<std::string> &msgPtr) override; virtual void send(const std::shared_ptr<MsgBuffer> &msgPtr) override; virtual void sendFile(const char *fileName, size_t offset = 0, size_t length = 0) override; virtual const InetAddress &localAddr() const override { return localAddr_; } virtual const InetAddress &peerAddr() const override { return peerAddr_; } virtual bool connected() const override { return status_ == ConnStatus::Connected; } virtual bool disconnected() const override { return status_ == ConnStatus::Disconnected; } // virtual MsgBuffer* getSendBuffer() override{ return &writeBuffer_;} virtual MsgBuffer *getRecvBuffer() override { return &readBuffer_; } // set callbacks virtual void setHighWaterMarkCallback(const HighWaterMarkCallback &cb, size_t markLen) override { highWaterMarkCallback_ = cb; highWaterMarkLen_ = markLen; } virtual void keepAlive() override { idleTimeout_ = 0; auto entry = kickoffEntry_.lock(); if (entry) { entry->reset(); } } virtual bool isKeepAlive() override { return idleTimeout_ == 0; } virtual void setTcpNoDelay(bool on) override; virtual void shutdown() override; virtual void forceClose() override; virtual EventLoop *getLoop() override { return loop_; } virtual size_t bytesSent() const override { return bytesSent_; } virtual size_t bytesReceived() const override { return bytesReceived_; } virtual void startClientEncryption( std::function<void()> callback, bool useOldTLS = false, bool validateCert = true, std::string hostname = "", const std::vector<std::pair<std::string, std::string>> &sslConfCmds = {}) override; virtual void startServerEncryption(const std::shared_ptr<SSLContext> &ctx, std::function<void()> callback) override; virtual bool isSSLConnection() const override { return isEncrypted_; } private: /// Internal use only. std::weak_ptr<KickoffEntry> kickoffEntry_; std::weak_ptr<TimingWheel> timingWheelWeakPtr_; size_t idleTimeout_{0}; Date lastTimingWheelUpdateTime_; void enableKickingOff(size_t timeout, const std::shared_ptr<TimingWheel> &timingWheel) { assert(timingWheel); assert(timingWheel->getLoop() == loop_); assert(timeout > 0); auto entry = std::make_shared<KickoffEntry>(shared_from_this()); kickoffEntry_ = entry; timingWheelWeakPtr_ = timingWheel; idleTimeout_ = timeout; timingWheel->insertEntry(timeout, entry); } void extendLife(); #ifndef _WIN32 void sendFile(int sfd, size_t offset = 0, size_t length = 0); #else void sendFile(FILE *fp, size_t offset = 0, size_t length = 0); #endif void setRecvMsgCallback(const RecvMessageCallback &cb) { recvMsgCallback_ = cb; } void setConnectionCallback(const ConnectionCallback &cb) { connectionCallback_ = cb; } void setWriteCompleteCallback(const WriteCompleteCallback &cb) { writeCompleteCallback_ = cb; } void setCloseCallback(const CloseCallback &cb) { closeCallback_ = cb; } void setSSLErrorCallback(const SSLErrorCallback &cb) { sslErrorCallback_ = cb; } void connectDestroyed(); virtual void connectEstablished(); protected: struct BufferNode { #ifndef _WIN32 int sendFd_{-1}; off_t offset_; #else FILE *sendFp_{nullptr}; long long offset_; #endif ssize_t fileBytesToSend_; std::shared_ptr<MsgBuffer> msgBuffer_; ~BufferNode() { #ifndef _WIN32 if (sendFd_ >= 0) close(sendFd_); #else if (sendFp_) fclose(sendFp_); #endif } }; using BufferNodePtr = std::shared_ptr<BufferNode>; enum class ConnStatus { Disconnected, Connecting, Connected, Disconnecting }; bool isEncrypted_{false}; EventLoop *loop_; std::unique_ptr<Channel> ioChannelPtr_; std::unique_ptr<Socket> socketPtr_; MsgBuffer readBuffer_; std::list<BufferNodePtr> writeBufferList_; void readCallback(); void writeCallback(); InetAddress localAddr_, peerAddr_; ConnStatus status_{ConnStatus::Connecting}; // callbacks RecvMessageCallback recvMsgCallback_; ConnectionCallback connectionCallback_; CloseCallback closeCallback_; WriteCompleteCallback writeCompleteCallback_; HighWaterMarkCallback highWaterMarkCallback_; SSLErrorCallback sslErrorCallback_; void handleClose(); void handleError(); // virtual void sendInLoop(const std::string &msg); void sendFileInLoop(const BufferNodePtr &file); #ifndef _WIN32 void sendInLoop(const void *buffer, size_t length); ssize_t writeInLoop(const void *buffer, size_t length); #else void sendInLoop(const char *buffer, size_t length); ssize_t writeInLoop(const char *buffer, size_t length); #endif size_t highWaterMarkLen_; std::string name_; uint64_t sendNum_{0}; std::mutex sendNumMutex_; size_t bytesSent_{0}; size_t bytesReceived_{0}; std::unique_ptr<std::vector<char>> fileBufferPtr_; #ifdef USE_OPENSSL private: void doHandshaking(); bool validatePeerCertificate(); struct SSLEncryption { SSLStatus statusOfSSL_ = SSLStatus::Handshaking; // OpenSSL std::shared_ptr<SSLContext> sslCtxPtr_; std::unique_ptr<SSLConn> sslPtr_; std::unique_ptr<std::array<char, 8192>> sendBufferPtr_; bool isServer_{false}; bool isUpgrade_{false}; std::function<void()> upgradeCallback_; std::string hostname_; }; std::unique_ptr<SSLEncryption> sslEncryptionPtr_; void startClientEncryptionInLoop( std::function<void()> &&callback, bool useOldTLS, bool validateCert, const std::string &hostname, const std::vector<std::pair<std::string, std::string>> &sslConfCmds); void startServerEncryptionInLoop(const std::shared_ptr<SSLContext> &ctx, std::function<void()> &&callback); #endif }; using TcpConnectionImplPtr = std::shared_ptr<TcpConnectionImpl>; } // namespace trantor