diff --git a/core/net/connections/tcp_connection_impl.cpp b/core/net/connections/tcp_connection_impl.cpp index b39c35b..fc35606 100644 --- a/core/net/connections/tcp_connection_impl.cpp +++ b/core/net/connections/tcp_connection_impl.cpp @@ -67,12 +67,11 @@ inline bool loadWindowsSystemCert(X509_STORE *store) { } PCCERT_CONTEXT pContext = NULL; - while ((pContext = CertEnumCertificatesInStore(hStore, pContext)) != - nullptr) { - auto encoded_cert = - static_cast(pContext->pbCertEncoded); + while ((pContext = CertEnumCertificatesInStore(hStore, pContext)) != nullptr) { + auto encoded_cert = static_cast(pContext->pbCertEncoded); auto x509 = d2i_X509(NULL, &encoded_cert, pContext->cbCertEncoded); + if (x509) { X509_STORE_add_cert(store, x509); X509_free(x509); @@ -125,15 +124,12 @@ inline bool verifyCommonName(X509 *cert, const std::string &hostname) { if (subjectName != nullptr) { std::array name; - auto length = X509_NAME_get_text_by_NID(subjectName, - NID_commonName, - name.data(), - (int)name.size()); + auto length = X509_NAME_get_text_by_NID(subjectName, NID_commonName, name.data(), (int)name.size()); + if (length == -1) return false; - return verifyName(std::string(name.begin(), name.begin() + length), - hostname); + return verifyName(std::string(name.begin(), name.begin() + length), hostname); } return false; @@ -141,8 +137,7 @@ inline bool verifyCommonName(X509 *cert, const std::string &hostname) { inline bool verifyAltName(X509 *cert, const std::string &hostname) { bool good = false; - auto altNames = static_cast( - X509_get_ext_d2i(cert, NID_subject_alt_name, nullptr, nullptr)); + auto altNames = static_cast(X509_get_ext_d2i(cert, NID_subject_alt_name, nullptr, nullptr)); if (altNames) { int numNames = sk_GENERAL_NAME_num(altNames); @@ -154,6 +149,7 @@ inline bool verifyAltName(X509 *cert, const std::string &hostname) { "an issue if you need that feature"; continue; } + #if (OPENSSL_VERSION_NUMBER >= 0x10100000L) auto name = (const char *)ASN1_STRING_get0_data(val->d.ia5); #else @@ -171,11 +167,10 @@ inline bool verifyAltName(X509 *cert, const std::string &hostname) { } // namespace internal void initOpenSSL() { -#if (OPENSSL_VERSION_NUMBER < 0x10100000L) || \ - (defined(LIBRESSL_VERSION_NUMBER) && \ - LIBRESSL_VERSION_NUMBER < 0x20700000L) +#if (OPENSSL_VERSION_NUMBER < 0x10100000L) || (defined(LIBRESSL_VERSION_NUMBER) && LIBRESSL_VERSION_NUMBER < 0x20700000L) // Initialize OpenSSL once; static std::once_flag once; + std::call_once(once, []() { SSL_library_init(); ERR_load_crypto_strings(); @@ -187,10 +182,7 @@ void initOpenSSL() { class SSLContext { public: - explicit SSLContext( - bool useOldTLS, - bool enableValidtion, - const std::vector > &sslConfCmds) { + explicit SSLContext(bool useOldTLS, bool enableValidtion, const std::vector > &sslConfCmds) { #if (OPENSSL_VERSION_NUMBER >= 0x10100000L) ctxPtr_ = SSL_CTX_new(TLS_method()); SSL_CONF_CTX *cctx = SSL_CONF_CTX_new(); @@ -199,9 +191,11 @@ public: SSL_CONF_CTX_set_flags(cctx, SSL_CONF_FLAG_CERTIFICATE); SSL_CONF_CTX_set_flags(cctx, SSL_CONF_FLAG_FILE); SSL_CONF_CTX_set_ssl_ctx(cctx, ctxPtr_); + for (const auto &cmd : sslConfCmds) { SSL_CONF_cmd(cctx, cmd.first.data(), cmd.second.data()); } + SSL_CONF_CTX_finish(cctx); if (!useOldTLS) { SSL_CTX_set_min_proto_version(ctxPtr_, TLS1_2_VERSION); @@ -210,6 +204,7 @@ public: "obsolete, insecure standards and should only be " "used for legacy purpose."; } + #else ctxPtr_ = SSL_CTX_new(SSLv23_method()); SSL_CONF_CTX *cctx = SSL_CONF_CTX_new(); @@ -218,10 +213,13 @@ public: SSL_CONF_CTX_set_flags(cctx, SSL_CONF_FLAG_CERTIFICATE); SSL_CONF_CTX_set_flags(cctx, SSL_CONF_FLAG_FILE); SSL_CONF_CTX_set_ssl_ctx(cctx, ctxPtr_); + for (const auto &cmd : sslConfCmds) { SSL_CONF_cmd(cctx, cmd.first.data(), cmd.second.data()); } + SSL_CONF_CTX_finish(cctx); + if (!useOldTLS) { SSL_CTX_set_options(ctxPtr_, SSL_OP_NO_TLSv1 | SSL_OP_NO_TLSv1_1); } else { @@ -238,6 +236,7 @@ public: SSL_CTX_set_default_verify_paths(ctxPtr_); #endif } + ~SSLContext() { if (ctxPtr_) { SSL_CTX_free(ctxPtr_); @@ -251,6 +250,7 @@ public: private: SSL_CTX *ctxPtr_; }; + class SSLConn { public: explicit SSLConn(SSL_CTX *ctx) { @@ -269,20 +269,17 @@ private: SSL *SSL_; }; -std::shared_ptr newSSLContext( - bool useOldTLS, - bool validateCert, - const std::vector > &sslConfCmds) { // init OpenSSL +std::shared_ptr newSSLContext(bool useOldTLS, bool validateCert, const std::vector > &sslConfCmds) { // init OpenSSL initOpenSSL(); return std::make_shared(useOldTLS, validateCert, sslConfCmds); } -std::shared_ptr newSSLServerContext( - const std::string &certPath, - const std::string &keyPath, - bool useOldTLS, - const std::vector > &sslConfCmds) { + +std::shared_ptr newSSLServerContext(const std::string &certPath, const std::string &keyPath, bool useOldTLS, const std::vector > &sslConfCmds) { + auto ctx = newSSLContext(useOldTLS, false, sslConfCmds); + auto r = SSL_CTX_use_certificate_chain_file(ctx->get(), certPath.c_str()); + if (!r) { #ifndef _MSC_VER LOG_FATAL << strerror(errno); @@ -291,9 +288,9 @@ std::shared_ptr newSSLServerContext( #endif abort(); } - r = SSL_CTX_use_PrivateKey_file(ctx->get(), - keyPath.c_str(), - SSL_FILETYPE_PEM); + + r = SSL_CTX_use_PrivateKey_file(ctx->get(), keyPath.c_str(), SSL_FILETYPE_PEM); + if (!r) { #ifndef _MSC_VER LOG_FATAL << strerror(errno); @@ -302,7 +299,9 @@ std::shared_ptr newSSLServerContext( #endif abort(); } + r = SSL_CTX_check_private_key(ctx->get()); + if (!r) { #ifndef _MSC_VER LOG_FATAL << strerror(errno); @@ -311,113 +310,100 @@ std::shared_ptr newSSLServerContext( #endif abort(); } + return ctx; } #else -std::shared_ptr newSSLServerContext( - const std::string &certPath, - const std::string &keyPath, - bool useOldTLS, - const std::vector > &sslConfCmds) { +std::shared_ptr newSSLServerContext(const std::string &certPath, const std::string &keyPath, bool useOldTLS, const std::vector > &sslConfCmds) { LOG_FATAL << "OpenSSL is not found in your system!"; abort(); } #endif -TcpConnectionImpl::TcpConnectionImpl(EventLoop *loop, - int socketfd, - const InetAddress &localAddr, - const InetAddress &peerAddr) : - loop_(loop), - ioChannelPtr_(new Channel(loop, socketfd)), - socketPtr_(new Socket(socketfd)), - localAddr_(localAddr), - peerAddr_(peerAddr) { - LOG_TRACE << "new connection:" << peerAddr.toIpPort() << "->" - << localAddr.toIpPort(); - ioChannelPtr_->setReadCallback( - std::bind(&TcpConnectionImpl::readCallback, this)); - ioChannelPtr_->setWriteCallback( - std::bind(&TcpConnectionImpl::writeCallback, this)); - ioChannelPtr_->setCloseCallback( - std::bind(&TcpConnectionImpl::handleClose, this)); - ioChannelPtr_->setErrorCallback( - std::bind(&TcpConnectionImpl::handleError, this)); +TcpConnectionImpl::TcpConnectionImpl(EventLoop *loop, int socketfd, const InetAddress &localAddr, const InetAddress &peerAddr) : + loop_(loop), ioChannelPtr_(new Channel(loop, socketfd)), socketPtr_(new Socket(socketfd)), localAddr_(localAddr), peerAddr_(peerAddr) { + LOG_TRACE << "new connection:" << peerAddr.toIpPort() << "->" << localAddr.toIpPort(); + + ioChannelPtr_->setReadCallback(std::bind(&TcpConnectionImpl::readCallback, this)); + ioChannelPtr_->setWriteCallback(std::bind(&TcpConnectionImpl::writeCallback, this)); + ioChannelPtr_->setCloseCallback(std::bind(&TcpConnectionImpl::handleClose, this)); + ioChannelPtr_->setErrorCallback(std::bind(&TcpConnectionImpl::handleError, this)); socketPtr_->setKeepAlive(true); + name_ = localAddr.toIpPort() + "--" + peerAddr.toIpPort(); } + TcpConnectionImpl::~TcpConnectionImpl() { } + #ifdef USE_OPENSSL -void TcpConnectionImpl::startClientEncryptionInLoop( - std::function &&callback, - bool useOldTLS, - bool validateCert, - const std::string &hostname, - const std::vector > &sslConfCmds) { +void TcpConnectionImpl::startClientEncryptionInLoop(std::function &&callback, bool useOldTLS, bool validateCert, const std::string &hostname, const std::vector > &sslConfCmds) { validateCert_ = validateCert; loop_->assertInLoopThread(); + if (isEncrypted_) { LOG_WARN << "This connection is already encrypted"; return; } + sslEncryptionPtr_ = std::make_unique(); sslEncryptionPtr_->upgradeCallback_ = std::move(callback); - sslEncryptionPtr_->sslCtxPtr_ = - newSSLContext(useOldTLS, validateCert_, sslConfCmds); - sslEncryptionPtr_->sslPtr_ = - std::make_unique(sslEncryptionPtr_->sslCtxPtr_->get()); + sslEncryptionPtr_->sslCtxPtr_ = newSSLContext(useOldTLS, validateCert_, sslConfCmds); + sslEncryptionPtr_->sslPtr_ = std::make_unique(sslEncryptionPtr_->sslCtxPtr_->get()); + if (validateCert) { - SSL_set_verify(sslEncryptionPtr_->sslPtr_->get(), - SSL_VERIFY_NONE, - nullptr); + SSL_set_verify(sslEncryptionPtr_->sslPtr_->get(), SSL_VERIFY_NONE, nullptr); validateCert_ = validateCert; } + if (!hostname.empty()) { SSL_set_tlsext_host_name(sslEncryptionPtr_->sslPtr_->get(), hostname.data()); sslEncryptionPtr_->hostname_ = hostname; } + isEncrypted_ = true; sslEncryptionPtr_->isUpgrade_ = true; auto r = SSL_set_fd(sslEncryptionPtr_->sslPtr_->get(), socketPtr_->fd()); (void)r; assert(r); - sslEncryptionPtr_->sendBufferPtr_ = - std::make_unique >(); + + sslEncryptionPtr_->sendBufferPtr_ = std::make_unique >(); LOG_TRACE << "connectEstablished"; ioChannelPtr_->enableWriting(); SSL_set_connect_state(sslEncryptionPtr_->sslPtr_->get()); } -void TcpConnectionImpl::startServerEncryptionInLoop( - const std::shared_ptr &ctx, - std::function &&callback) { + +void TcpConnectionImpl::startServerEncryptionInLoop(const std::shared_ptr &ctx, std::function &&callback) { loop_->assertInLoopThread(); + if (isEncrypted_) { LOG_WARN << "This connection is already encrypted"; return; } + sslEncryptionPtr_ = std::make_unique(); sslEncryptionPtr_->upgradeCallback_ = std::move(callback); sslEncryptionPtr_->sslCtxPtr_ = ctx; sslEncryptionPtr_->isServer_ = true; - sslEncryptionPtr_->sslPtr_ = - std::make_unique(sslEncryptionPtr_->sslCtxPtr_->get()); + sslEncryptionPtr_->sslPtr_ = std::make_unique(sslEncryptionPtr_->sslCtxPtr_->get()); isEncrypted_ = true; sslEncryptionPtr_->isUpgrade_ = true; - if (sslEncryptionPtr_->isServer_ == false) - SSL_set_verify(sslEncryptionPtr_->sslPtr_->get(), - SSL_VERIFY_NONE, - nullptr); + + if (sslEncryptionPtr_->isServer_ == false) { + SSL_set_verify(sslEncryptionPtr_->sslPtr_->get(), SSL_VERIFY_NONE, nullptr); + } + auto r = SSL_set_fd(sslEncryptionPtr_->sslPtr_->get(), socketPtr_->fd()); (void)r; assert(r); - sslEncryptionPtr_->sendBufferPtr_ = - std::make_unique >(); + + sslEncryptionPtr_->sendBufferPtr_ = std::make_unique >(); LOG_TRACE << "upgrade to ssl"; + SSL_set_accept_state(sslEncryptionPtr_->sslPtr_->get()); } #endif @@ -440,16 +426,13 @@ void TcpConnectionImpl::startServerEncryption( #endif } -void TcpConnectionImpl::startClientEncryption( - std::function callback, - bool useOldTLS, - bool validateCert, - std::string hostname, - const std::vector > &sslConfCmds) { + +void TcpConnectionImpl::startClientEncryption(std::function callback, bool useOldTLS, bool validateCert, std::string hostname, const std::vector > &sslConfCmds) { #ifndef USE_OPENSSL LOG_FATAL << "OpenSSL is not found in your system!"; abort(); #else + if (!hostname.empty()) { std::transform(hostname.begin(), hostname.end(), @@ -458,28 +441,17 @@ void TcpConnectionImpl::startClientEncryption( assert(sslEncryptionPtr_ != nullptr); sslEncryptionPtr_->hostname_ = hostname; } + if (loop_->isInLoopThread()) { - startClientEncryptionInLoop(std::move(callback), - useOldTLS, - validateCert, - hostname, - sslConfCmds); + startClientEncryptionInLoop(std::move(callback), useOldTLS, validateCert, hostname, sslConfCmds); } else { - loop_->queueInLoop([thisPtr = shared_from_this(), - callback = std::move(callback), - useOldTLS, - hostname = std::move(hostname), - validateCert, - &sslConfCmds]() mutable { - thisPtr->startClientEncryptionInLoop(std::move(callback), - useOldTLS, - validateCert, - hostname, - sslConfCmds); + loop_->queueInLoop([thisPtr = shared_from_this(), callback = std::move(callback), useOldTLS, hostname = std::move(hostname), validateCert, &sslConfCmds]() mutable { + thisPtr->startClientEncryptionInLoop(std::move(callback), useOldTLS, validateCert, hostname, sslConfCmds); }); } #endif } + void TcpConnectionImpl::readCallback() { // LOG_TRACE<<"read Callback"; #ifdef USE_OPENSSL @@ -510,17 +482,21 @@ void TcpConnectionImpl::readCallback() { handleClose(); return; } + extendLife(); + if (n > 0) { bytesReceived_ += n; if (recvMsgCallback_) { recvMsgCallback_(shared_from_this(), &readBuffer_); } } + #ifdef USE_OPENSSL } else { LOG_TRACE << "read Callback"; loop_->assertInLoopThread(); + if (sslEncryptionPtr_->statusOfSSL_ == SSLStatus::Handshaking) { doHandshaking(); return; @@ -564,8 +540,10 @@ void TcpConnectionImpl::extendLife() { auto now = Date::date(); if (now < lastTimingWheelUpdateTime_.after(1.0)) return; + lastTimingWheelUpdateTime_ = now; auto entry = kickoffEntry_.lock(); + if (entry) { auto timingWheelPtr = timingWheelWeakPtr_.lock(); if (timingWheelPtr) @@ -575,12 +553,11 @@ void TcpConnectionImpl::extendLife() { } void TcpConnectionImpl::writeCallback() { #ifdef USE_OPENSSL - if (!isEncrypted_ || - (sslEncryptionPtr_ && - sslEncryptionPtr_->statusOfSSL_ == SSLStatus::Connected)) { + if (!isEncrypted_ || (sslEncryptionPtr_ && sslEncryptionPtr_->statusOfSSL_ == SSLStatus::Connected)) { #endif loop_->assertInLoopThread(); extendLife(); + if (ioChannelPtr_->isWriting()) { assert(!writeBufferList_.empty()); auto writeBuffer_ = writeBufferList_.front(); @@ -591,6 +568,7 @@ void TcpConnectionImpl::writeCallback() { #endif { if (writeBuffer_->msgBuffer_->readableBytes() <= 0) { + writeBufferList_.pop_front(); if (writeBufferList_.empty()) { ioChannelPtr_->disableWriting(); @@ -600,6 +578,7 @@ void TcpConnectionImpl::writeCallback() { if (status_ == ConnStatus::Disconnecting) { socketPtr_->closeWrite(); } + } else { auto fileNode = writeBufferList_.front(); #ifndef _WIN32 @@ -611,8 +590,7 @@ void TcpConnectionImpl::writeCallback() { } } else { auto n = - writeInLoop(writeBuffer_->msgBuffer_->peek(), - writeBuffer_->msgBuffer_->readableBytes()); + writeInLoop(writeBuffer_->msgBuffer_->peek(), writeBuffer_->msgBuffer_->readableBytes()); if (n >= 0) { writeBuffer_->msgBuffer_->retrieve(n); } else { @@ -636,6 +614,7 @@ void TcpConnectionImpl::writeCallback() { } else { // file if (writeBuffer_->fileBytesToSend_ <= 0) { + writeBufferList_.pop_front(); if (writeBufferList_.empty()) { ioChannelPtr_->disableWriting(); @@ -644,6 +623,7 @@ void TcpConnectionImpl::writeCallback() { if (status_ == ConnStatus::Disconnecting) { socketPtr_->closeWrite(); } + } else { #ifndef _WIN32 if (writeBufferList_.front()->sendFd_ < 0) @@ -652,13 +632,10 @@ void TcpConnectionImpl::writeCallback() { #endif { // There is data to be sent in the buffer. - auto n = writeInLoop( - writeBufferList_.front()->msgBuffer_->peek(), - writeBufferList_.front() - ->msgBuffer_->readableBytes()); + auto n = writeInLoop(writeBufferList_.front()->msgBuffer_->peek(), writeBufferList_.front()->msgBuffer_->readableBytes()); + if (n >= 0) { - writeBufferList_.front()->msgBuffer_->retrieve( - n); + writeBufferList_.front()->msgBuffer_->retrieve(n); } else { #ifdef _WIN32 if (errno != 0 && errno != EWOULDBLOCK) @@ -701,6 +678,7 @@ void TcpConnectionImpl::writeCallback() { } #endif } + void TcpConnectionImpl::connectEstablished() { // loop_->assertInLoopThread(); #ifdef USE_OPENSSL @@ -713,8 +691,9 @@ void TcpConnectionImpl::connectEstablished() { thisPtr->ioChannelPtr_->tie(thisPtr); thisPtr->ioChannelPtr_->enableReading(); thisPtr->status_ = ConnStatus::Connected; - if (thisPtr->connectionCallback_) + if (thisPtr->connectionCallback_) { thisPtr->connectionCallback_(thisPtr); + } }); #ifdef USE_OPENSSL } else { @@ -724,27 +703,31 @@ void TcpConnectionImpl::connectEstablished() { thisPtr->ioChannelPtr_->tie(thisPtr); thisPtr->ioChannelPtr_->enableReading(); thisPtr->status_ = ConnStatus::Connected; + if (thisPtr->sslEncryptionPtr_->isServer_) { - SSL_set_accept_state( - thisPtr->sslEncryptionPtr_->sslPtr_->get()); + SSL_set_accept_state(thisPtr->sslEncryptionPtr_->sslPtr_->get()); } else { thisPtr->ioChannelPtr_->enableWriting(); - SSL_set_connect_state( - thisPtr->sslEncryptionPtr_->sslPtr_->get()); + SSL_set_connect_state(thisPtr->sslEncryptionPtr_->sslPtr_->get()); } }); } #endif } + void TcpConnectionImpl::handleClose() { LOG_TRACE << "connection closed, fd=" << socketPtr_->fd(); loop_->assertInLoopThread(); status_ = ConnStatus::Disconnected; ioChannelPtr_->disableAll(); + // ioChannelPtr_->remove(); auto guardThis = shared_from_this(); - if (connectionCallback_) + + if (connectionCallback_) { connectionCallback_(guardThis); + } + if (closeCallback_) { LOG_TRACE << "to call close callback"; closeCallback_(guardThis); @@ -752,8 +735,10 @@ void TcpConnectionImpl::handleClose() { } void TcpConnectionImpl::handleError() { int err = socketPtr_->getSocketError(); + if (err == 0) return; + if (err == EPIPE || err == ECONNRESET || err == 104) { LOG_DEBUG << "[" << name_ << "] - SO_ERROR = " << err << " " << strerror_tl(err); @@ -762,11 +747,14 @@ void TcpConnectionImpl::handleError() { << strerror_tl(err); } } + void TcpConnectionImpl::setTcpNoDelay(bool on) { socketPtr_->setTcpNoDelay(on); } + void TcpConnectionImpl::connectDestroyed() { loop_->assertInLoopThread(); + if (status_ == ConnStatus::Connected) { status_ = ConnStatus::Disconnected; ioChannelPtr_->disableAll(); @@ -777,6 +765,7 @@ void TcpConnectionImpl::connectDestroyed() { } void TcpConnectionImpl::shutdown() { auto thisPtr = shared_from_this(); + loop_->runInLoop([thisPtr]() { if (thisPtr->status_ == ConnStatus::Connected) { thisPtr->status_ = ConnStatus::Disconnecting; @@ -789,6 +778,7 @@ void TcpConnectionImpl::shutdown() { void TcpConnectionImpl::forceClose() { auto thisPtr = shared_from_this(); + loop_->runInLoop([thisPtr]() { if (thisPtr->status_ == ConnStatus::Connected || thisPtr->status_ == ConnStatus::Disconnecting) { @@ -797,6 +787,7 @@ void TcpConnectionImpl::forceClose() { } }); } + #ifndef _WIN32 void TcpConnectionImpl::sendInLoop(const void *buffer, size_t length) #else @@ -804,16 +795,20 @@ void TcpConnectionImpl::sendInLoop(const char *buffer, size_t length) #endif { loop_->assertInLoopThread(); + if (status_ != ConnStatus::Connected) { LOG_WARN << "Connection is not connected,give up sending"; return; } + extendLife(); size_t remainLen = length; ssize_t sendLen = 0; + if (!ioChannelPtr_->isWriting() && writeBufferList_.empty()) { // send directly sendLen = writeInLoop(buffer, length); + if (sendLen < 0) { // error #ifdef _WIN32 @@ -827,6 +822,7 @@ void TcpConnectionImpl::sendInLoop(const char *buffer, size_t length) LOG_DEBUG << "EPIPE or ECONNRESET, erron=" << errno; return; } + LOG_SYSERR << "Unexpected error(" << errno << ")"; return; } @@ -834,12 +830,15 @@ void TcpConnectionImpl::sendInLoop(const char *buffer, size_t length) } remainLen -= sendLen; } + if (remainLen > 0 && status_ == ConnStatus::Connected) { + if (writeBufferList_.empty()) { BufferNodePtr node(new BufferNode); node->msgBuffer_ = std::make_shared(); writeBufferList_.push_back(std::move(node)); } + #ifndef _WIN32 else if (writeBufferList_.back()->sendFd_ >= 0) #else @@ -850,16 +849,15 @@ void TcpConnectionImpl::sendInLoop(const char *buffer, size_t length) node->msgBuffer_ = std::make_shared(); writeBufferList_.push_back(std::move(node)); } - writeBufferList_.back()->msgBuffer_->append( - static_cast(buffer) + sendLen, remainLen); - if (!ioChannelPtr_->isWriting()) + + writeBufferList_.back()->msgBuffer_->append(static_cast(buffer) + sendLen, remainLen); + + if (!ioChannelPtr_->isWriting()) { ioChannelPtr_->enableWriting(); - if (highWaterMarkCallback_ && - writeBufferList_.back()->msgBuffer_->readableBytes() > - highWaterMarkLen_) { - highWaterMarkCallback_( - shared_from_this(), - writeBufferList_.back()->msgBuffer_->readableBytes()); + } + + if (highWaterMarkCallback_ && writeBufferList_.back()->msgBuffer_->readableBytes() > highWaterMarkLen_) { + highWaterMarkCallback_(shared_from_this(), writeBufferList_.back()->msgBuffer_->readableBytes()); } } } @@ -872,6 +870,7 @@ void TcpConnectionImpl::send(const std::shared_ptr &msgPtr) { } else { ++sendNum_; auto thisPtr = shared_from_this(); + loop_->queueInLoop([thisPtr, msgPtr]() { thisPtr->sendInLoop(msgPtr->data(), msgPtr->length()); std::lock_guard guard1(thisPtr->sendNumMutex_); @@ -882,6 +881,7 @@ void TcpConnectionImpl::send(const std::shared_ptr &msgPtr) { auto thisPtr = shared_from_this(); std::lock_guard guard(sendNumMutex_); ++sendNum_; + loop_->queueInLoop([thisPtr, msgPtr]() { thisPtr->sendInLoop(msgPtr->data(), msgPtr->length()); std::lock_guard guard1(thisPtr->sendNumMutex_); @@ -898,6 +898,7 @@ void TcpConnectionImpl::send(const std::shared_ptr &msgPtr) { } else { ++sendNum_; auto thisPtr = shared_from_this(); + loop_->queueInLoop([thisPtr, msgPtr]() { thisPtr->sendInLoop(msgPtr->peek(), msgPtr->readableBytes()); std::lock_guard guard1(thisPtr->sendNumMutex_); @@ -908,6 +909,7 @@ void TcpConnectionImpl::send(const std::shared_ptr &msgPtr) { auto thisPtr = shared_from_this(); std::lock_guard guard(sendNumMutex_); ++sendNum_; + loop_->queueInLoop([thisPtr, msgPtr]() { thisPtr->sendInLoop(msgPtr->peek(), msgPtr->readableBytes()); std::lock_guard guard1(thisPtr->sendNumMutex_); @@ -924,6 +926,7 @@ void TcpConnectionImpl::send(const char *msg, size_t len) { ++sendNum_; auto buffer = std::make_shared(msg, len); auto thisPtr = shared_from_this(); + loop_->queueInLoop([thisPtr, buffer]() { thisPtr->sendInLoop(buffer->data(), buffer->length()); std::lock_guard guard1(thisPtr->sendNumMutex_); @@ -935,6 +938,7 @@ void TcpConnectionImpl::send(const char *msg, size_t len) { auto thisPtr = shared_from_this(); std::lock_guard guard(sendNumMutex_); ++sendNum_; + loop_->queueInLoop([thisPtr, buffer]() { thisPtr->sendInLoop(buffer->data(), buffer->length()); std::lock_guard guard1(thisPtr->sendNumMutex_); @@ -957,6 +961,7 @@ void TcpConnectionImpl::send(const void *msg, size_t len) { std::make_shared(static_cast(msg), len); auto thisPtr = shared_from_this(); + loop_->queueInLoop([thisPtr, buffer]() { thisPtr->sendInLoop(buffer->data(), buffer->length()); std::lock_guard guard1(thisPtr->sendNumMutex_); @@ -969,6 +974,7 @@ void TcpConnectionImpl::send(const void *msg, size_t len) { auto thisPtr = shared_from_this(); std::lock_guard guard(sendNumMutex_); ++sendNum_; + loop_->queueInLoop([thisPtr, buffer]() { thisPtr->sendInLoop(buffer->data(), buffer->length()); std::lock_guard guard1(thisPtr->sendNumMutex_); @@ -984,6 +990,7 @@ void TcpConnectionImpl::send(const std::string &msg) { } else { ++sendNum_; auto thisPtr = shared_from_this(); + loop_->queueInLoop([thisPtr, msg]() { thisPtr->sendInLoop(msg.data(), msg.length()); std::lock_guard guard1(thisPtr->sendNumMutex_); @@ -994,6 +1001,7 @@ void TcpConnectionImpl::send(const std::string &msg) { auto thisPtr = shared_from_this(); std::lock_guard guard(sendNumMutex_); ++sendNum_; + loop_->queueInLoop([thisPtr, msg]() { thisPtr->sendInLoop(msg.data(), msg.length()); std::lock_guard guard1(thisPtr->sendNumMutex_); @@ -1009,6 +1017,7 @@ void TcpConnectionImpl::send(std::string &&msg) { } else { auto thisPtr = shared_from_this(); ++sendNum_; + loop_->queueInLoop([thisPtr, msg = std::move(msg)]() { thisPtr->sendInLoop(msg.data(), msg.length()); std::lock_guard guard1(thisPtr->sendNumMutex_); @@ -1019,6 +1028,7 @@ void TcpConnectionImpl::send(std::string &&msg) { auto thisPtr = shared_from_this(); std::lock_guard guard(sendNumMutex_); ++sendNum_; + loop_->queueInLoop([thisPtr, msg = std::move(msg)]() { thisPtr->sendInLoop(msg.data(), msg.length()); std::lock_guard guard1(thisPtr->sendNumMutex_); @@ -1035,6 +1045,7 @@ void TcpConnectionImpl::send(const MsgBuffer &buffer) { } else { ++sendNum_; auto thisPtr = shared_from_this(); + loop_->queueInLoop([thisPtr, buffer]() { thisPtr->sendInLoop(buffer.peek(), buffer.readableBytes()); std::lock_guard guard1(thisPtr->sendNumMutex_); @@ -1045,6 +1056,7 @@ void TcpConnectionImpl::send(const MsgBuffer &buffer) { auto thisPtr = shared_from_this(); std::lock_guard guard(sendNumMutex_); ++sendNum_; + loop_->queueInLoop([thisPtr, buffer]() { thisPtr->sendInLoop(buffer.peek(), buffer.readableBytes()); std::lock_guard guard1(thisPtr->sendNumMutex_); @@ -1054,23 +1066,28 @@ void TcpConnectionImpl::send(const MsgBuffer &buffer) { } void TcpConnectionImpl::send(MsgBuffer &&buffer) { + if (loop_->isInLoopThread()) { std::lock_guard guard(sendNumMutex_); + if (sendNum_ == 0) { sendInLoop(buffer.peek(), buffer.readableBytes()); } else { ++sendNum_; auto thisPtr = shared_from_this(); + loop_->queueInLoop([thisPtr, buffer = std::move(buffer)]() { thisPtr->sendInLoop(buffer.peek(), buffer.readableBytes()); std::lock_guard guard1(thisPtr->sendNumMutex_); --thisPtr->sendNum_; }); } + } else { auto thisPtr = shared_from_this(); std::lock_guard guard(sendNumMutex_); ++sendNum_; + loop_->queueInLoop([thisPtr, buffer = std::move(buffer)]() { thisPtr->sendInLoop(buffer.peek(), buffer.readableBytes()); std::lock_guard guard1(thisPtr->sendNumMutex_); @@ -1078,10 +1095,11 @@ void TcpConnectionImpl::send(MsgBuffer &&buffer) { }); } } -void TcpConnectionImpl::sendFile(const char *fileName, - size_t offset, - size_t length) { + +void TcpConnectionImpl::sendFile(const char *fileName, size_t offset, size_t length) { + assert(fileName); + #ifndef _WIN32 int fd = open(fileName, O_RDONLY); @@ -1097,6 +1115,7 @@ void TcpConnectionImpl::sendFile(const char *fileName, close(fd); return; } + length = filestat.st_size; } @@ -1122,6 +1141,7 @@ void TcpConnectionImpl::sendFile(const char *fileName, fclose(fp); return; } + length = filestat.st_size; } @@ -1136,6 +1156,7 @@ void TcpConnectionImpl::sendFile(FILE *fp, size_t offset, size_t length) #endif { assert(length > 0); + #ifndef _WIN32 assert(sfd >= 0); BufferNodePtr node(new BufferNode); @@ -1145,19 +1166,25 @@ void TcpConnectionImpl::sendFile(FILE *fp, size_t offset, size_t length) BufferNodePtr node(new BufferNode); node->sendFp_ = fp; #endif + node->offset_ = static_cast(offset); node->fileBytesToSend_ = length; + if (loop_->isInLoopThread()) { std::lock_guard guard(sendNumMutex_); if (sendNum_ == 0) { + writeBufferList_.push_back(node); if (writeBufferList_.size() == 1) { sendFileInLoop(writeBufferList_.front()); return; } + } else { + ++sendNum_; auto thisPtr = shared_from_this(); + loop_->queueInLoop([thisPtr, node]() { thisPtr->writeBufferList_.push_back(node); { @@ -1171,9 +1198,11 @@ void TcpConnectionImpl::sendFile(FILE *fp, size_t offset, size_t length) }); } } else { + auto thisPtr = shared_from_this(); std::lock_guard guard(sendNumMutex_); ++sendNum_; + loop_->queueInLoop([thisPtr, node]() { LOG_TRACE << "Push sendfile to list"; thisPtr->writeBufferList_.push_back(node); @@ -1191,49 +1220,56 @@ void TcpConnectionImpl::sendFile(FILE *fp, size_t offset, size_t length) } void TcpConnectionImpl::sendFileInLoop(const BufferNodePtr &filePtr) { + loop_->assertInLoopThread(); + #ifndef _WIN32 assert(filePtr->sendFd_ >= 0); #else assert(filePtr->sendFp_); #endif + #ifdef __linux__ if (!isEncrypted_) { - auto bytesSent = sendfile(socketPtr_->fd(), - filePtr->sendFd_, - &filePtr->offset_, - filePtr->fileBytesToSend_); + auto bytesSent = sendfile(socketPtr_->fd(), filePtr->sendFd_, &filePtr->offset_, filePtr->fileBytesToSend_); + if (bytesSent < 0) { if (errno != EAGAIN) { LOG_SYSERR << "TcpConnectionImpl::sendFileInLoop"; - if (ioChannelPtr_->isWriting()) + if (ioChannelPtr_->isWriting()) { ioChannelPtr_->disableWriting(); + } } return; } + if (bytesSent < filePtr->fileBytesToSend_) { if (bytesSent == 0) { LOG_SYSERR << "TcpConnectionImpl::sendFileInLoop"; return; } } + LOG_TRACE << "sendfile() " << bytesSent << " bytes sent"; filePtr->fileBytesToSend_ -= bytesSent; if (!ioChannelPtr_->isWriting()) { ioChannelPtr_->enableWriting(); } + return; } #endif + #ifndef _WIN32 lseek(filePtr->sendFd_, filePtr->offset_, SEEK_SET); + if (!fileBufferPtr_) { fileBufferPtr_ = std::make_unique >(16 * 1024); } + while (filePtr->fileBytesToSend_ > 0) { - auto n = read(filePtr->sendFd_, - &(*fileBufferPtr_)[0], - fileBufferPtr_->size()); + auto n = read(filePtr->sendFd_, &(*fileBufferPtr_)[0], fileBufferPtr_->size()); + #else _fseeki64(filePtr->sendFp_, filePtr->offset_, SEEK_SET); if (!fileBufferPtr_) { @@ -1245,6 +1281,7 @@ void TcpConnectionImpl::sendFileInLoop(const BufferNodePtr &filePtr) { fileBufferPtr_->size(), filePtr->sendFp_); #endif + if (n > 0) { auto nSend = writeInLoop(&(*fileBufferPtr_)[0], n); if (nSend >= 0) { @@ -1270,9 +1307,11 @@ void TcpConnectionImpl::sendFileInLoop(const BufferNodePtr &filePtr) { LOG_DEBUG << "EPIPE or ECONNRESET, erron=" << errno; return; } + LOG_SYSERR << "Unexpected error(" << errno << ")"; return; } + break; } } @@ -1282,15 +1321,18 @@ void TcpConnectionImpl::sendFileInLoop(const BufferNodePtr &filePtr) { ioChannelPtr_->disableWriting(); return; } + if (n == 0) { LOG_SYSERR << "read"; return; } } + if (!ioChannelPtr_->isWriting()) { ioChannelPtr_->enableWriting(); } } + #ifndef _WIN32 ssize_t TcpConnectionImpl::writeInLoop(const void *buffer, size_t length) #else @@ -1320,33 +1362,37 @@ ssize_t TcpConnectionImpl::writeInLoop(const char *buffer, size_t length) LOG_WARN << "SSL is not connected,give up sending"; return -1; } + // send directly size_t sendTotalLen = 0; while (sendTotalLen < length) { + auto len = length - sendTotalLen; + if (len > sslEncryptionPtr_->sendBufferPtr_->size()) { len = sslEncryptionPtr_->sendBufferPtr_->size(); } - memcpy(sslEncryptionPtr_->sendBufferPtr_->data(), - static_cast(buffer) + sendTotalLen, - len); + + memcpy(sslEncryptionPtr_->sendBufferPtr_->data(), static_cast(buffer) + sendTotalLen, len); + ERR_clear_error(); - auto sendLen = SSL_write(sslEncryptionPtr_->sslPtr_->get(), - sslEncryptionPtr_->sendBufferPtr_->data(), - static_cast(len)); + + auto sendLen = SSL_write(sslEncryptionPtr_->sslPtr_->get(), sslEncryptionPtr_->sendBufferPtr_->data(), static_cast(len)); + if (sendLen <= 0) { - int sslerr = - SSL_get_error(sslEncryptionPtr_->sslPtr_->get(), sendLen); - if (sslerr != SSL_ERROR_WANT_WRITE && - sslerr != SSL_ERROR_WANT_READ) { + int sslerr = SSL_get_error(sslEncryptionPtr_->sslPtr_->get(), sendLen); + + if (sslerr != SSL_ERROR_WANT_WRITE && sslerr != SSL_ERROR_WANT_READ) { // LOG_ERROR << "ssl write error:" << sslerr; forceClose(); return -1; } + return sendTotalLen; } sendTotalLen += sendLen; } + return sendTotalLen; } #endif @@ -1354,52 +1400,45 @@ ssize_t TcpConnectionImpl::writeInLoop(const char *buffer, size_t length) #ifdef USE_OPENSSL -TcpConnectionImpl::TcpConnectionImpl(EventLoop *loop, - int socketfd, - const InetAddress &localAddr, - const InetAddress &peerAddr, - const std::shared_ptr &ctxPtr, - bool isServer, - bool validateCert, - const std::string &hostname) : +TcpConnectionImpl::TcpConnectionImpl(EventLoop *loop, int socketfd, const InetAddress &localAddr, const InetAddress &peerAddr, const std::shared_ptr &ctxPtr, bool isServer, bool validateCert, const std::string &hostname) : isEncrypted_(true), loop_(loop), ioChannelPtr_(new Channel(loop, socketfd)), socketPtr_(new Socket(socketfd)), localAddr_(localAddr), peerAddr_(peerAddr) { - LOG_TRACE << "new connection:" << peerAddr.toIpPort() << "->" - << localAddr.toIpPort(); - ioChannelPtr_->setReadCallback( - std::bind(&TcpConnectionImpl::readCallback, this)); - ioChannelPtr_->setWriteCallback( - std::bind(&TcpConnectionImpl::writeCallback, this)); - ioChannelPtr_->setCloseCallback( - std::bind(&TcpConnectionImpl::handleClose, this)); - ioChannelPtr_->setErrorCallback( - std::bind(&TcpConnectionImpl::handleError, this)); + + LOG_TRACE << "new connection:" << peerAddr.toIpPort() << "->" << localAddr.toIpPort(); + + ioChannelPtr_->setReadCallback(std::bind(&TcpConnectionImpl::readCallback, this)); + ioChannelPtr_->setWriteCallback(std::bind(&TcpConnectionImpl::writeCallback, this)); + ioChannelPtr_->setCloseCallback(std::bind(&TcpConnectionImpl::handleClose, this)); + ioChannelPtr_->setErrorCallback(std::bind(&TcpConnectionImpl::handleError, this)); + socketPtr_->setKeepAlive(true); name_ = localAddr.toIpPort() + "--" + peerAddr.toIpPort(); sslEncryptionPtr_ = std::make_unique(); sslEncryptionPtr_->sslPtr_ = std::make_unique(ctxPtr->get()); sslEncryptionPtr_->isServer_ = isServer; validateCert_ = validateCert; - if (isServer == false) - SSL_set_verify(sslEncryptionPtr_->sslPtr_->get(), - SSL_VERIFY_NONE, - nullptr); + + if (isServer == false) { + SSL_set_verify(sslEncryptionPtr_->sslPtr_->get(), SSL_VERIFY_NONE, nullptr); + } + if (!isServer && !hostname.empty()) { - SSL_set_tlsext_host_name(sslEncryptionPtr_->sslPtr_->get(), - hostname.data()); + SSL_set_tlsext_host_name(sslEncryptionPtr_->sslPtr_->get(), hostname.data()); sslEncryptionPtr_->hostname_ = hostname; } + assert(sslEncryptionPtr_->sslPtr_); auto r = SSL_set_fd(sslEncryptionPtr_->sslPtr_->get(), socketfd); + (void)r; assert(r); + isEncrypted_ = true; - sslEncryptionPtr_->sendBufferPtr_ = - std::make_unique >(); + sslEncryptionPtr_->sendBufferPtr_ = std::make_unique >(); } bool TcpConnectionImpl::validatePeerCertificate() { @@ -1409,6 +1448,7 @@ bool TcpConnectionImpl::validatePeerCertificate() { SSL *ssl = sslEncryptionPtr_->sslPtr_->get(); auto result = SSL_get_verify_result(ssl); + if (result != X509_V_OK) { LOG_DEBUG << "cert error code: " << result; LOG_ERROR << "Server certificate is not valid"; @@ -1416,14 +1456,13 @@ bool TcpConnectionImpl::validatePeerCertificate() { } X509 *cert = SSL_get_peer_certificate(ssl); + if (cert == nullptr) { LOG_ERROR << "Unable to obtain peer certificate"; return false; } - bool domainIsValid = - internal::verifyCommonName(cert, sslEncryptionPtr_->hostname_) || - internal::verifyAltName(cert, sslEncryptionPtr_->hostname_); + bool domainIsValid = internal::verifyCommonName(cert, sslEncryptionPtr_->hostname_) || internal::verifyAltName(cert, sslEncryptionPtr_->hostname_); X509_free(cert); if (domainIsValid) { @@ -1439,10 +1478,12 @@ void TcpConnectionImpl::doHandshaking() { int r = SSL_do_handshake(sslEncryptionPtr_->sslPtr_->get()); LOG_TRACE << "hand shaking: " << r; + if (r == 1) { // Clients don't commonly have certificates. Let's not validate // that if (validateCert_ && sslEncryptionPtr_->isServer_ == false) { + if (validatePeerCertificate() == false) { LOG_ERROR << "SSL certificate validation failed."; ioChannelPtr_->disableReading(); @@ -1454,7 +1495,9 @@ void TcpConnectionImpl::doHandshaking() { return; } } + sslEncryptionPtr_->statusOfSSL_ = SSLStatus::Connected; + if (sslEncryptionPtr_->isUpgrade_) { sslEncryptionPtr_->upgradeCallback_(); } else { @@ -1462,25 +1505,33 @@ void TcpConnectionImpl::doHandshaking() { } return; } + int err = SSL_get_error(sslEncryptionPtr_->sslPtr_->get(), r); LOG_TRACE << "hand shaking: " << err; + if (err == SSL_ERROR_WANT_WRITE) { // SSL want writable; - if (!ioChannelPtr_->isWriting()) + if (!ioChannelPtr_->isWriting()) { ioChannelPtr_->enableWriting(); + } // ioChannelPtr_->disableReading(); } else if (err == SSL_ERROR_WANT_READ) { // SSL want readable; - if (!ioChannelPtr_->isReading()) + if (!ioChannelPtr_->isReading()) { ioChannelPtr_->enableReading(); - if (ioChannelPtr_->isWriting()) + } + + if (ioChannelPtr_->isWriting()) { ioChannelPtr_->disableWriting(); + } } else { // ERR_print_errors(err); LOG_TRACE << "SSL handshake err: " << err; ioChannelPtr_->disableReading(); + sslEncryptionPtr_->statusOfSSL_ = SSLStatus::DisConnected; if (sslErrorCallback_) { sslErrorCallback_(SSLError::kSSLHandshakeError); } + forceClose(); } } diff --git a/core/net/connections/tcp_connection_impl.h b/core/net/connections/tcp_connection_impl.h index 7bccab2..6ffb3a8 100644 --- a/core/net/connections/tcp_connection_impl.h +++ b/core/net/connections/tcp_connection_impl.h @@ -30,54 +30,45 @@ #pragma once -#include "core/net/tcp_connection.h" #include "core/loops/timing_wheel.h" +#include "core/net/tcp_connection.h" #include #include #ifndef _WIN32 #include #endif -#include #include - +#include #ifdef USE_OPENSSL -enum class SSLStatus -{ - Handshaking, - Connecting, - Connected, - DisConnecting, - DisConnected + +enum class SSLStatus { + Handshaking, + Connecting, + Connected, + DisConnecting, + DisConnected }; + class SSLContext; class SSLConn; -std::shared_ptr newSSLContext( - bool useOldTLS, - bool validateCert, - const std::vector> &sslConfCmds); -std::shared_ptr newSSLServerContext( - const std::string &certPath, - const std::string &keyPath, - bool useOldTLS, - const std::vector> &sslConfCmds); -// void initServerSSLContext(const std::shared_ptr &ctx, -// const std::string &certPath, -// const std::string &keyPath); +std::shared_ptr newSSLContext(bool useOldTLS, bool validateCert, const std::vector > &sslConfCmds); +std::shared_ptr newSSLServerContext(const std::string &certPath, const std::string &keyPath, bool useOldTLS, const std::vector > &sslConfCmds); +// void initServerSSLContext(const std::shared_ptr &ctx, const std::string &certPath, const std::string &keyPath); + #endif + class Channel; class Socket; class TcpServer; void removeConnection(EventLoop *loop, const TcpConnectionPtr &conn); -class TcpConnectionImpl : public TcpConnection, - public std::enable_shared_from_this -{ - friend class TcpServer; - friend class TcpClient; - friend void removeConnection(EventLoop *loop, - const TcpConnectionPtr &conn); +class TcpConnectionImpl : public TcpConnection, public std::enable_shared_from_this { + friend class TcpServer; + friend class TcpClient; + + friend void removeConnection(EventLoop *loop, const TcpConnectionPtr &conn); protected: TcpConnectionImpl(const TcpConnectionImpl &) = delete; @@ -86,278 +77,247 @@ protected: TcpConnectionImpl(TcpConnectionImpl &&) noexcept(true) = default; TcpConnectionImpl &operator=(TcpConnectionImpl &&) noexcept(true) = default; - public: - class KickoffEntry - { - public: - explicit KickoffEntry(const std::weak_ptr &conn) - : conn_(conn) - { - } - void reset() - { - conn_.reset(); - } - ~KickoffEntry() - { - auto conn = conn_.lock(); - if (conn) - { - conn->forceClose(); - } - } +public: + class KickoffEntry { + public: + explicit KickoffEntry(const std::weak_ptr &conn) : + conn_(conn) { + } - private: - std::weak_ptr conn_; - }; + void reset() { + conn_.reset(); + } - TcpConnectionImpl(EventLoop *loop, - int socketfd, - const InetAddress &localAddr, - const InetAddress &peerAddr); + ~KickoffEntry() { + auto conn = conn_.lock(); + if (conn) { + conn->forceClose(); + } + } + + private: + std::weak_ptr conn_; + }; + + TcpConnectionImpl(EventLoop *loop, int socketfd, const InetAddress &localAddr, const InetAddress &peerAddr); #ifdef USE_OPENSSL - TcpConnectionImpl(EventLoop *loop, - int socketfd, - const InetAddress &localAddr, - const InetAddress &peerAddr, - const std::shared_ptr &ctxPtr, - bool isServer = true, - bool validateCert = true, - const std::string &hostname = ""); + TcpConnectionImpl(EventLoop *loop, int socketfd, const InetAddress &localAddr, const InetAddress &peerAddr, const std::shared_ptr &ctxPtr, bool isServer = true, bool validateCert = true, const std::string &hostname = ""); #endif - virtual ~TcpConnectionImpl(); - virtual void send(const char *msg, size_t len) override; - virtual void send(const void *msg, size_t len) override; - virtual void send(const std::string &msg) override; - virtual void send(std::string &&msg) override; - virtual void send(const MsgBuffer &buffer) override; - virtual void send(MsgBuffer &&buffer) override; - virtual void send(const std::shared_ptr &msgPtr) override; - virtual void send(const std::shared_ptr &msgPtr) override; - virtual void sendFile(const char *fileName, - size_t offset = 0, - size_t length = 0) override; - virtual const InetAddress &localAddr() const override - { - return localAddr_; - } - virtual const InetAddress &peerAddr() const override - { - return peerAddr_; - } + virtual ~TcpConnectionImpl(); - virtual bool connected() const override - { - return status_ == ConnStatus::Connected; - } - virtual bool disconnected() const override - { - return status_ == ConnStatus::Disconnected; - } + virtual void send(const char *msg, size_t len) override; + virtual void send(const void *msg, size_t len) override; + virtual void send(const std::string &msg) override; + virtual void send(std::string &&msg) override; + virtual void send(const MsgBuffer &buffer) override; + virtual void send(MsgBuffer &&buffer) override; + virtual void send(const std::shared_ptr &msgPtr) override; + virtual void send(const std::shared_ptr &msgPtr) override; + virtual void sendFile(const char *fileName, size_t offset = 0, size_t length = 0) override; - // virtual MsgBuffer* getSendBuffer() override{ return &writeBuffer_;} - virtual MsgBuffer *getRecvBuffer() override - { - return &readBuffer_; - } - // set callbacks - virtual void setHighWaterMarkCallback(const HighWaterMarkCallback &cb, - size_t markLen) override - { - highWaterMarkCallback_ = cb; - highWaterMarkLen_ = markLen; - } + virtual const InetAddress &localAddr() const override { + return localAddr_; + } - virtual void keepAlive() override - { - idleTimeout_ = 0; - auto entry = kickoffEntry_.lock(); - if (entry) - { - entry->reset(); - } - } - virtual bool isKeepAlive() override - { - return idleTimeout_ == 0; - } - virtual void setTcpNoDelay(bool on) override; - virtual void shutdown() override; - virtual void forceClose() override; - virtual EventLoop *getLoop() override - { - return loop_; - } + virtual const InetAddress &peerAddr() const override { + return peerAddr_; + } - virtual size_t bytesSent() const override - { - return bytesSent_; - } - virtual size_t bytesReceived() const override - { - return bytesReceived_; - } - virtual void startClientEncryption( - std::function callback, - bool useOldTLS = false, - bool validateCert = true, - std::string hostname = "", - const std::vector> &sslConfCmds = - {}) override; - virtual void startServerEncryption(const std::shared_ptr &ctx, - std::function callback) override; - virtual bool isSSLConnection() const override - { - return isEncrypted_; - } + virtual bool connected() const override { + return status_ == ConnStatus::Connected; + } - private: - /// Internal use only. + virtual bool disconnected() const override { + return status_ == ConnStatus::Disconnected; + } - std::weak_ptr kickoffEntry_; - std::weak_ptr timingWheelWeakPtr_; - size_t idleTimeout_{0}; - Date lastTimingWheelUpdateTime_; + // virtual MsgBuffer* getSendBuffer() override{ return &writeBuffer_;} + virtual MsgBuffer *getRecvBuffer() override { + return &readBuffer_; + } - void enableKickingOff(size_t timeout, - const std::shared_ptr &timingWheel) - { - assert(timingWheel); - assert(timingWheel->getLoop() == loop_); - assert(timeout > 0); - auto entry = std::make_shared(shared_from_this()); - kickoffEntry_ = entry; - timingWheelWeakPtr_ = timingWheel; - idleTimeout_ = timeout; - timingWheel->insertEntry(timeout, entry); - } - void extendLife(); + // set callbacks + virtual void setHighWaterMarkCallback(const HighWaterMarkCallback &cb, + size_t markLen) override { + highWaterMarkCallback_ = cb; + highWaterMarkLen_ = markLen; + } + + virtual void keepAlive() override { + idleTimeout_ = 0; + auto entry = kickoffEntry_.lock(); + if (entry) { + entry->reset(); + } + } + + virtual bool isKeepAlive() override { + return idleTimeout_ == 0; + } + + virtual void setTcpNoDelay(bool on) override; + virtual void shutdown() override; + virtual void forceClose() override; + + virtual EventLoop *getLoop() override { + return loop_; + } + + virtual size_t bytesSent() const override { + return bytesSent_; + } + + virtual size_t bytesReceived() const override { + return bytesReceived_; + } + + virtual void startClientEncryption(std::function callback, bool useOldTLS = false, bool validateCert = true, std::string hostname = "", const std::vector > &sslConfCmds = {}) override; + virtual void startServerEncryption(const std::shared_ptr &ctx, std::function callback) override; + virtual bool isSSLConnection() const override { + return isEncrypted_; + } + +private: + /// Internal use only. + + std::weak_ptr kickoffEntry_; + std::weak_ptr timingWheelWeakPtr_; + size_t idleTimeout_{ 0 }; + Date lastTimingWheelUpdateTime_; + + void enableKickingOff(size_t timeout, + const std::shared_ptr &timingWheel) { + assert(timingWheel); + assert(timingWheel->getLoop() == loop_); + assert(timeout > 0); + auto entry = std::make_shared(shared_from_this()); + kickoffEntry_ = entry; + timingWheelWeakPtr_ = timingWheel; + idleTimeout_ = timeout; + timingWheel->insertEntry(timeout, entry); + } + + void extendLife(); #ifndef _WIN32 - void sendFile(int sfd, size_t offset = 0, size_t length = 0); + void sendFile(int sfd, size_t offset = 0, size_t length = 0); #else - void sendFile(FILE *fp, size_t offset = 0, size_t length = 0); + void sendFile(FILE *fp, size_t offset = 0, size_t length = 0); #endif - void setRecvMsgCallback(const RecvMessageCallback &cb) - { - recvMsgCallback_ = cb; - } - void setConnectionCallback(const ConnectionCallback &cb) - { - connectionCallback_ = cb; - } - void setWriteCompleteCallback(const WriteCompleteCallback &cb) - { - writeCompleteCallback_ = cb; - } - void setCloseCallback(const CloseCallback &cb) - { - closeCallback_ = cb; - } - void setSSLErrorCallback(const SSLErrorCallback &cb) - { - sslErrorCallback_ = cb; - } + void setRecvMsgCallback(const RecvMessageCallback &cb) { + recvMsgCallback_ = cb; + } + void setConnectionCallback(const ConnectionCallback &cb) { + connectionCallback_ = cb; + } + void setWriteCompleteCallback(const WriteCompleteCallback &cb) { + writeCompleteCallback_ = cb; + } + void setCloseCallback(const CloseCallback &cb) { + closeCallback_ = cb; + } + void setSSLErrorCallback(const SSLErrorCallback &cb) { + sslErrorCallback_ = cb; + } - void connectDestroyed(); - virtual void connectEstablished(); + void connectDestroyed(); + virtual void connectEstablished(); - protected: - struct BufferNode - { +protected: + struct BufferNode { #ifndef _WIN32 - int sendFd_{-1}; - off_t offset_; + int sendFd_{ -1 }; + off_t offset_; #else - FILE *sendFp_{nullptr}; - long long offset_; + FILE *sendFp_{ nullptr }; + long long offset_; #endif - ssize_t fileBytesToSend_; - std::shared_ptr msgBuffer_; - ~BufferNode() - { + ssize_t fileBytesToSend_; + std::shared_ptr msgBuffer_; + + ~BufferNode() { #ifndef _WIN32 - if (sendFd_ >= 0) - close(sendFd_); + if (sendFd_ >= 0) + close(sendFd_); #else - if (sendFp_) - fclose(sendFp_); + if (sendFp_) + fclose(sendFp_); #endif - } - }; - using BufferNodePtr = std::shared_ptr; - enum class ConnStatus - { - Disconnected, - Connecting, - Connected, - Disconnecting - }; - bool isEncrypted_{false}; - EventLoop *loop_; - std::unique_ptr ioChannelPtr_; - std::unique_ptr socketPtr_; - MsgBuffer readBuffer_; - std::list writeBufferList_; - void readCallback(); - void writeCallback(); - InetAddress localAddr_, peerAddr_; - ConnStatus status_{ConnStatus::Connecting}; - // callbacks - RecvMessageCallback recvMsgCallback_; - ConnectionCallback connectionCallback_; - CloseCallback closeCallback_; - WriteCompleteCallback writeCompleteCallback_; - HighWaterMarkCallback highWaterMarkCallback_; - SSLErrorCallback sslErrorCallback_; - void handleClose(); - void handleError(); - // virtual void sendInLoop(const std::string &msg); + } + }; - void sendFileInLoop(const BufferNodePtr &file); + using BufferNodePtr = std::shared_ptr; + enum class ConnStatus { + Disconnected, + Connecting, + Connected, + Disconnecting + }; + + bool isEncrypted_{ false }; + EventLoop *loop_; + std::unique_ptr ioChannelPtr_; + std::unique_ptr socketPtr_; + MsgBuffer readBuffer_; + std::list writeBufferList_; + + void readCallback(); + void writeCallback(); + + InetAddress localAddr_, peerAddr_; + ConnStatus status_{ ConnStatus::Connecting }; + // callbacks + RecvMessageCallback recvMsgCallback_; + ConnectionCallback connectionCallback_; + CloseCallback closeCallback_; + WriteCompleteCallback writeCompleteCallback_; + HighWaterMarkCallback highWaterMarkCallback_; + SSLErrorCallback sslErrorCallback_; + + void handleClose(); + void handleError(); + // virtual void sendInLoop(const std::string &msg); + + void sendFileInLoop(const BufferNodePtr &file); #ifndef _WIN32 - void sendInLoop(const void *buffer, size_t length); - ssize_t writeInLoop(const void *buffer, size_t length); + void sendInLoop(const void *buffer, size_t length); + ssize_t writeInLoop(const void *buffer, size_t length); #else - void sendInLoop(const char *buffer, size_t length); - ssize_t writeInLoop(const char *buffer, size_t length); + void sendInLoop(const char *buffer, size_t length); + ssize_t writeInLoop(const char *buffer, size_t length); #endif - size_t highWaterMarkLen_; - std::string name_; + size_t highWaterMarkLen_; + std::string name_; - uint64_t sendNum_{0}; - std::mutex sendNumMutex_; + uint64_t sendNum_{ 0 }; + std::mutex sendNumMutex_; - size_t bytesSent_{0}; - size_t bytesReceived_{0}; + size_t bytesSent_{ 0 }; + size_t bytesReceived_{ 0 }; - std::unique_ptr> fileBufferPtr_; + std::unique_ptr > fileBufferPtr_; #ifdef USE_OPENSSL - private: - void doHandshaking(); - bool validatePeerCertificate(); - struct SSLEncryption - { - SSLStatus statusOfSSL_ = SSLStatus::Handshaking; - // OpenSSL - std::shared_ptr sslCtxPtr_; - std::unique_ptr sslPtr_; - std::unique_ptr> sendBufferPtr_; - bool isServer_{false}; - bool isUpgrade_{false}; - std::function upgradeCallback_; - std::string hostname_; - }; - std::unique_ptr sslEncryptionPtr_; - void startClientEncryptionInLoop( - std::function &&callback, - bool useOldTLS, - bool validateCert, - const std::string &hostname, - const std::vector> &sslConfCmds); - void startServerEncryptionInLoop(const std::shared_ptr &ctx, - std::function &&callback); +private: + void doHandshaking(); + bool validatePeerCertificate(); + + struct SSLEncryption { + SSLStatus statusOfSSL_ = SSLStatus::Handshaking; + // OpenSSL + std::shared_ptr sslCtxPtr_; + std::unique_ptr sslPtr_; + std::unique_ptr > sendBufferPtr_; + bool isServer_{ false }; + bool isUpgrade_{ false }; + std::function upgradeCallback_; + std::string hostname_; + }; + + std::unique_ptr sslEncryptionPtr_; + + void startClientEncryptionInLoop(std::function &&callback, bool useOldTLS, bool validateCert, const std::string &hostname, const std::vector > &sslConfCmds); + void startServerEncryptionInLoop(const std::shared_ptr &ctx, std::function &&callback); #endif }; diff --git a/core/net/connector.cpp b/core/net/connector.cpp index 8337808..18a7fc3 100644 --- a/core/net/connector.cpp +++ b/core/net/connector.cpp @@ -35,6 +35,7 @@ Connector::Connector(EventLoop *loop, const InetAddress &addr, bool retry) : loop_(loop), serverAddr_(addr), retry_(retry) { } + Connector::Connector(EventLoop *loop, InetAddress &&addr, bool retry) : loop_(loop), serverAddr_(std::move(addr)), retry_(retry) { } @@ -43,8 +44,10 @@ void Connector::start() { connect_ = true; loop_->runInLoop([this]() { startInLoop(); }); } + void Connector::restart() { } + void Connector::stop() { } @@ -57,6 +60,7 @@ void Connector::startInLoop() { LOG_DEBUG << "do not connect"; } } + void Connector::connect() { int sockfd = Socket::createNonblockingSocketOrDie(serverAddr_.family()); errno = 0; diff --git a/core/net/connector.h b/core/net/connector.h index a2e7358..c10e41a 100644 --- a/core/net/connector.h +++ b/core/net/connector.h @@ -46,23 +46,30 @@ protected: public: using NewConnectionCallback = std::function; using ConnectionErrorCallback = std::function; + Connector(EventLoop *loop, const InetAddress &addr, bool retry = true); Connector(EventLoop *loop, InetAddress &&addr, bool retry = true); + void setNewConnectionCallback(const NewConnectionCallback &cb) { newConnectionCallback_ = cb; } + void setNewConnectionCallback(NewConnectionCallback &&cb) { newConnectionCallback_ = std::move(cb); } + void setErrorCallback(const ConnectionErrorCallback &cb) { errorCallback_ = cb; } + void setErrorCallback(ConnectionErrorCallback &&cb) { errorCallback_ = std::move(cb); } + const InetAddress &serverAddress() const { return serverAddr_; } + void start(); void restart(); void stop(); @@ -70,11 +77,13 @@ public: private: NewConnectionCallback newConnectionCallback_; ConnectionErrorCallback errorCallback_; + enum class Status { Disconnected, Connecting, Connected }; + static constexpr int kMaxRetryDelayMs = 30 * 1000; static constexpr int kInitRetryDelayMs = 500; std::shared_ptr channelPtr_; diff --git a/core/net/inet_address.cpp b/core/net/inet_address.cpp index 6c56e7a..ba3b9d9 100644 --- a/core/net/inet_address.cpp +++ b/core/net/inet_address.cpp @@ -98,8 +98,7 @@ InetAddress::InetAddress(uint16_t port, bool loopbackOnly, bool ipv6) : isUnspecified_ = false; } -InetAddress::InetAddress(const std::string &ip, uint16_t port, bool ipv6) : - isIpV6_(ipv6) { +InetAddress::InetAddress(const std::string &ip, uint16_t port, bool ipv6) : isIpV6_(ipv6) { if (ipv6) { memset(&addr6_, 0, sizeof(addr6_)); addr6_.sin6_family = AF_INET6; diff --git a/core/net/inet_address.h b/core/net/inet_address.h index b392266..45f716b 100644 --- a/core/net/inet_address.h +++ b/core/net/inet_address.h @@ -38,195 +38,76 @@ using sa_family_t = unsigned short; using in_addr_t = uint32_t; using uint16_t = unsigned short; #else -#include #include +#include #include #endif +#include #include #include -#include -/** - * @brief Wrapper of sockaddr_in. This is an POD interface class. - * - */ -class InetAddress -{ - public: - /** - * @brief Constructs an endpoint with given port number. Mostly used in - * TcpServer listening. - * - * @param port - * @param loopbackOnly - * @param ipv6 - */ - InetAddress(uint16_t port = 0, - bool loopbackOnly = false, - bool ipv6 = false); +class InetAddress { +public: + InetAddress(uint16_t port = 0, bool loopbackOnly = false, bool ipv6 = false); - /** - * @brief Constructs an endpoint with given ip and port. - * - * @param ip A IPv4 or IPv6 address. - * @param port - * @param ipv6 - */ - InetAddress(const std::string &ip, uint16_t port, bool ipv6 = false); + InetAddress(const std::string &ip, uint16_t port, bool ipv6 = false); - /** - * @brief Constructs an endpoint with given struct `sockaddr_in`. Mostly - * used when accepting new connections - * - * @param addr - */ - explicit InetAddress(const struct sockaddr_in &addr) - : addr_(addr), isUnspecified_(false) - { - } + explicit InetAddress(const struct sockaddr_in &addr) : + addr_(addr), isUnspecified_(false) { + } - /** - * @brief Constructs an IPv6 endpoint with given struct `sockaddr_in6`. - * Mostly used when accepting new connections - * - * @param addr - */ - explicit InetAddress(const struct sockaddr_in6 &addr) - : addr6_(addr), isIpV6_(true), isUnspecified_(false) - { - } + explicit InetAddress(const struct sockaddr_in6 &addr) : + addr6_(addr), isIpV6_(true), isUnspecified_(false) { + } - /** - * @brief Return the sin_family of the endpoint. - * - * @return sa_family_t - */ - sa_family_t family() const - { - return addr_.sin_family; - } + sa_family_t family() const { + return addr_.sin_family; + } - /** - * @brief Return the IP string of the endpoint. - * - * @return std::string - */ - std::string toIp() const; + std::string toIp() const; + std::string toIpPort() const; + uint16_t toPort() const; - /** - * @brief Return the IP and port string of the endpoint. - * - * @return std::string - */ - std::string toIpPort() const; + bool isIpV6() const { + return isIpV6_; + } - /** - * @brief Return the port number of the endpoint. - * - * @return uint16_t - */ - uint16_t toPort() const; + bool isIntranetIp() const; + bool isLoopbackIp() const; - /** - * @brief Check if the endpoint is IPv4 or IPv6. - * - * @return true - * @return false - */ - bool isIpV6() const - { - return isIpV6_; - } + const struct sockaddr *getSockAddr() const { + return static_cast((void *)(&addr6_)); + } - /** - * @brief Return true if the endpoint is an intranet endpoint. - * - * @return true - * @return false - */ - bool isIntranetIp() const; + void setSockAddrInet6(const struct sockaddr_in6 &addr6) { + addr6_ = addr6; + isIpV6_ = (addr6_.sin6_family == AF_INET6); + isUnspecified_ = false; + } - /** - * @brief Return true if the endpoint is a loopback endpoint. - * - * @return true - * @return false - */ - bool isLoopbackIp() const; + uint32_t ipNetEndian() const; + const uint32_t *ip6NetEndian() const; - /** - * @brief Get the pointer to the sockaddr struct. - * - * @return const struct sockaddr* - */ - const struct sockaddr *getSockAddr() const - { - return static_cast((void *)(&addr6_)); - } + uint16_t portNetEndian() const { + return addr_.sin_port; + } - /** - * @brief Set the sockaddr_in6 struct in the endpoint. - * - * @param addr6 - */ - void setSockAddrInet6(const struct sockaddr_in6 &addr6) - { - addr6_ = addr6; - isIpV6_ = (addr6_.sin6_family == AF_INET6); - isUnspecified_ = false; - } + void setPortNetEndian(uint16_t port) { + addr_.sin_port = port; + } - /** - * @brief Return the integer value of the IP(v4) in net endian byte order. - * - * @return uint32_t - */ - uint32_t ipNetEndian() const; + inline bool isUnspecified() const { + return isUnspecified_; + } - /** - * @brief Return the pointer to the integer value of the IP(v6) in net - * endian byte order. - * - * @return const uint32_t* - */ - const uint32_t *ip6NetEndian() const; +private: + union { + struct sockaddr_in addr_; + struct sockaddr_in6 addr6_; + }; - /** - * @brief Return the port number in net endian byte order. - * - * @return uint16_t - */ - uint16_t portNetEndian() const - { - return addr_.sin_port; - } - - /** - * @brief Set the port number in net endian byte order. - * - * @param port - */ - void setPortNetEndian(uint16_t port) - { - addr_.sin_port = port; - } - - /** - * @brief Return true if the address is not initalized. - */ - inline bool isUnspecified() const - { - return isUnspecified_; - } - - private: - union - { - struct sockaddr_in addr_; - struct sockaddr_in6 addr6_; - }; - bool isIpV6_{false}; - bool isUnspecified_{true}; + bool isIpV6_{ false }; + bool isUnspecified_{ true }; }; -#endif // MUDUO_NET_INETADDRESS_H +#endif // MUDUO_NET_INETADDRESS_H diff --git a/core/net/resolver.h b/core/net/resolver.h index 7ae58c1..1b2fdf7 100644 --- a/core/net/resolver.h +++ b/core/net/resolver.h @@ -34,44 +34,20 @@ #include "core/loops/event_loop.h" #include "core/net/inet_address.h" -/** - * @brief This class represents an asynchronous DNS resolver. - * @note Although the c-ares library is not essential, it is recommended to - * install it for higher performance - */ +//make it a reference + class Resolver { public: using Callback = std::function; - /** - * @brief Create a new DNS resolver. - * - * @param loop The event loop in which the DNS resolver runs. - * @param timeout The timeout in seconds for DNS. - * @return std::shared_ptr - */ - static std::shared_ptr newResolver(EventLoop* loop = nullptr, - size_t timeout = 60); + static std::shared_ptr newResolver(EventLoop* loop = nullptr, size_t timeout = 60); - /** - * @brief Resolve an address asynchronously. - * - * @param hostname - * @param callback - */ - virtual void resolve(const std::string& hostname, - const Callback& callback) = 0; + virtual void resolve(const std::string& hostname, const Callback& callback) = 0; virtual ~Resolver() { } - /** - * @brief Check whether the c-ares library is used. - * - * @return true - * @return false - */ static bool isCAresUsed(); }; diff --git a/core/net/resolvers/ares_resolver.cpp b/core/net/resolvers/ares_resolver.cpp index 8988d86..a457771 100644 --- a/core/net/resolvers/ares_resolver.cpp +++ b/core/net/resolvers/ares_resolver.cpp @@ -29,8 +29,8 @@ // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. #include "ares_resolver.h" -#include #include "core/loops/channel.h" +#include #ifdef _WIN32 #include #else @@ -67,14 +67,14 @@ bool Resolver::isCAresUsed() { AresResolver::LibraryInitializer::LibraryInitializer() { ares_library_init(ARES_LIB_INIT_ALL); } + AresResolver::LibraryInitializer::~LibraryInitializer() { ares_library_cleanup(); } AresResolver::LibraryInitializer AresResolver::libraryInitializer_; -std::shared_ptr Resolver::newResolver(EventLoop *loop, - size_t timeout) { +std::shared_ptr Resolver::newResolver(EventLoop *loop, size_t timeout) { return std::make_shared(loop, timeout); } @@ -84,6 +84,7 @@ AresResolver::AresResolver(EventLoop *loop, size_t timeout) : loop_ = getLoop(); } } + void AresResolver::init() { if (!ctx_) { struct ares_options options; @@ -108,14 +109,16 @@ void AresResolver::init() { this); } } + AresResolver::~AresResolver() { if (ctx_) ares_destroy(ctx_); } -void AresResolver::resolveInLoop(const std::string &hostname, - const Callback &cb) { +void AresResolver::resolveInLoop(const std::string &hostname, const Callback &cb) { + loop_->assertInLoopThread(); + #ifdef _WIN32 if (hostname == "localhost") { const static InetAddress localhost_{ "127.0.0.1", 0 }; @@ -123,20 +126,17 @@ void AresResolver::resolveInLoop(const std::string &hostname, return; } #endif + init(); QueryData *queryData = new QueryData(this, cb, hostname); - ares_gethostbyname(ctx_, - hostname.c_str(), - AF_INET, - &AresResolver::ares_hostcallback_, - queryData); + ares_gethostbyname(ctx_, hostname.c_str(), AF_INET, &AresResolver::ares_hostcallback_, queryData); struct timeval tv; struct timeval *tvp = ares_timeout(ctx_, NULL, &tv); double timeout = getSeconds(tvp); + // LOG_DEBUG << "timeout " << timeout << " active " << timerActive_; if (!timerActive_ && timeout >= 0.0) { - loop_->runAfter(timeout, - std::bind(&AresResolver::onTimer, shared_from_this())); + loop_->runAfter(timeout, std::bind(&AresResolver::onTimer, shared_from_this())); timerActive_ = true; } return; @@ -161,18 +161,18 @@ void AresResolver::onTimer() { } } -void AresResolver::onQueryResult(int status, - struct hostent *result, - const std::string &hostname, - const Callback &callback) { +void AresResolver::onQueryResult(int status, struct hostent *result, const std::string &hostname, const Callback &callback) { + LOG_TRACE << "onQueryResult " << status; struct sockaddr_in addr; memset(&addr, 0, sizeof addr); addr.sin_family = AF_INET; addr.sin_port = 0; + if (result) { addr.sin_addr = *reinterpret_cast(result->h_addr); } + InetAddress inet(addr); { std::lock_guard lock(globalMutex()); @@ -180,6 +180,7 @@ void AresResolver::onQueryResult(int status, addrItem.first = addr.sin_addr; addrItem.second = Date::date(); } + callback(inet); } @@ -198,6 +199,7 @@ void AresResolver::onSockStateChange(int sockfd, bool read, bool write) { loop_->assertInLoopThread(); ChannelList::iterator it = channels_.find(sockfd); assert(it != channels_.end()); + if (read) { // update // if (write) { } else { } @@ -209,17 +211,12 @@ void AresResolver::onSockStateChange(int sockfd, bool read, bool write) { } } -void AresResolver::ares_hostcallback_(void *data, - int status, - int timeouts, - struct hostent *hostent) { +void AresResolver::ares_hostcallback_(void *data, int status, int timeouts, struct hostent *hostent) { (void)timeouts; QueryData *query = static_cast(data); - query->owner_->onQueryResult(status, - hostent, - query->hostname_, - query->callback_); + query->owner_->onQueryResult(status, hostent, query->hostname_, query->callback_); + delete query; } @@ -242,6 +239,7 @@ void AresResolver::ares_sock_statecallback_(void *data, #endif int read, int write) { + LOG_TRACE << "sockfd=" << sockfd << " read=" << read << " write=" << write; static_cast(data)->onSockStateChange(sockfd, read, write); } diff --git a/core/net/resolvers/ares_resolver.h b/core/net/resolvers/ares_resolver.h index 9b0cbc3..f77a645 100644 --- a/core/net/resolvers/ares_resolver.h +++ b/core/net/resolvers/ares_resolver.h @@ -36,14 +36,15 @@ #include #include +// Resolver will be a ref + extern "C" { struct hostent; struct ares_channeldata; using ares_channel = struct ares_channeldata *; } -class AresResolver : public Resolver, - public std::enable_shared_from_this { +class AresResolver : public Resolver, public std::enable_shared_from_this { protected: AresResolver(const AresResolver &) = delete; AresResolver &operator=(const AresResolver &) = delete; @@ -55,9 +56,9 @@ public: AresResolver(EventLoop *loop, size_t timeout); ~AresResolver(); - virtual void resolve(const std::string &hostname, - const Callback &cb) override { + virtual void resolve(const std::string &hostname, const Callback &cb) override { bool cached = false; + InetAddress inet; { std::lock_guard lock(globalMutex()); @@ -76,10 +77,12 @@ public: } } } + if (cached) { cb(inet); return; } + if (loop_->isInLoopThread()) { resolveInLoop(hostname, cb); } else { @@ -94,14 +97,15 @@ private: AresResolver *owner_; Callback callback_; std::string hostname_; - QueryData(AresResolver *o, - const Callback &cb, - const std::string &hostname) : + + QueryData(AresResolver *o, const Callback &cb, const std::string &hostname) : owner_(o), callback_(cb), hostname_(hostname) { } }; + void resolveInLoop(const std::string &hostname, const Callback &cb); void init(); + EventLoop *loop_; ares_channel ctx_{ nullptr }; bool timerActive_{ false }; @@ -109,36 +113,33 @@ private: ChannelList channels_; static std::unordered_map > & + globalCache() { - static std::unordered_map > - dnsCache; + static std::unordered_map > dnsCache; return dnsCache; } + static std::mutex &globalMutex() { static std::mutex mutex_; return mutex_; } + static EventLoop *getLoop() { static EventLoopThread loopThread; loopThread.run(); return loopThread.getLoop(); } + const size_t timeout_{ 60 }; void onRead(int sockfd); void onTimer(); - void onQueryResult(int status, - struct hostent *result, - const std::string &hostname, - const Callback &callback); + void onQueryResult(int status, struct hostent *result, const std::string &hostname, const Callback &callback); void onSockCreate(int sockfd, int type); void onSockStateChange(int sockfd, bool read, bool write); - static void ares_hostcallback_(void *data, - int status, - int timeouts, - struct hostent *hostent); + static void ares_hostcallback_(void *data, int status, int timeouts, struct hostent *hostent); + #ifdef _WIN32 static int ares_sock_createcallback_(SOCKET sockfd, int type, void *data); #else @@ -152,9 +153,11 @@ private: #endif int read, int write); + struct LibraryInitializer { LibraryInitializer(); ~LibraryInitializer(); }; + static LibraryInitializer libraryInitializer_; }; diff --git a/core/net/resolvers/normal_resolver.cpp b/core/net/resolvers/normal_resolver.cpp index 40afd92..0354506 100644 --- a/core/net/resolvers/normal_resolver.cpp +++ b/core/net/resolvers/normal_resolver.cpp @@ -44,9 +44,11 @@ std::shared_ptr Resolver::newResolver(EventLoop *, size_t timeout) { return std::make_shared(timeout); } + bool Resolver::isCAresUsed() { return false; } + void NormalResolver::resolve(const std::string &hostname, const Callback &callback) { { diff --git a/core/net/resolvers/normal_resolver.h b/core/net/resolvers/normal_resolver.h index e6909cc..5d3265e 100644 --- a/core/net/resolvers/normal_resolver.h +++ b/core/net/resolvers/normal_resolver.h @@ -35,10 +35,12 @@ #include #include +//Resolver will be a ref + constexpr size_t kResolveBufferLength{ 16 * 1024 }; -class NormalResolver : public Resolver, - public std::enable_shared_from_this { +class NormalResolver : public Resolver, public std::enable_shared_from_this { + protected: NormalResolver(const NormalResolver &) = delete; NormalResolver &operator=(const NormalResolver &) = delete; @@ -56,25 +58,23 @@ public: } private: - static std::unordered_map > & - globalCache() { - static std::unordered_map< - std::string, - std::pair > - dnsCache_; + static std::unordered_map > &globalCache() { + static std::unordered_map > dnsCache_; return dnsCache_; } + static std::mutex &globalMutex() { static std::mutex mutex_; return mutex_; } + static ConcurrentTaskQueue &concurrentTaskQueue() { static ConcurrentTaskQueue queue( std::thread::hardware_concurrency() < 8 ? 8 : std::thread::hardware_concurrency(), "Dns Queue"); return queue; } + const size_t timeout_; std::vector resolveBuffer_; }; diff --git a/core/net/socket.cpp b/core/net/socket.cpp index f44225f..c3669e8 100644 --- a/core/net/socket.cpp +++ b/core/net/socket.cpp @@ -29,9 +29,9 @@ // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. #include "socket.h" +#include "core/log/logger.h" #include #include -#include "core/log/logger.h" #ifdef _WIN32 #include #else @@ -39,7 +39,6 @@ #include #endif - bool Socket::isSelfConnect(int sockfd) { struct sockaddr_in6 localaddr = getLocalAddr(sockfd); struct sockaddr_in6 peeraddr = getPeerAddr(sockfd); @@ -75,6 +74,7 @@ void Socket::bindAddress(const InetAddress &localaddr) { exit(1); } } + void Socket::listen() { assert(sockFd_ > 0); int ret = ::listen(sockFd_, SOMAXCONN); @@ -83,6 +83,7 @@ void Socket::listen() { exit(1); } } + int Socket::accept(InetAddress *peeraddr) { struct sockaddr_in6 addr6; memset(&addr6, 0, sizeof(addr6)); @@ -102,6 +103,7 @@ int Socket::accept(InetAddress *peeraddr) { } return connfd; } + void Socket::closeWrite() { #ifndef _WIN32 if (::shutdown(sockFd_, SHUT_WR) < 0) @@ -112,6 +114,7 @@ void Socket::closeWrite() { LOG_SYSERR << "sockets::shutdownWrite"; } } + int Socket::read(char *buffer, uint64_t len) { #ifndef _WIN32 return ::read(sockFd_, buffer, len); diff --git a/core/net/socket.h b/core/net/socket.h index 70d6b82..c926e3b 100644 --- a/core/net/socket.h +++ b/core/net/socket.h @@ -32,8 +32,8 @@ #pragma once -#include "core/net/inet_address.h" #include "core/log/logger.h" +#include "core/net/inet_address.h" #include #ifndef _WIN32 #include @@ -41,7 +41,6 @@ #include class Socket { - protected: Socket(const Socket &) = delete; Socket &operator=(const Socket &) = delete; @@ -63,6 +62,7 @@ public: LOG_SYSERR << "sockets::createNonblockingOrDie"; exit(1); } + LOG_TRACE << "sock=" << sock; return sock; } @@ -85,15 +85,9 @@ public: static int connect(int sockfd, const InetAddress &addr) { if (addr.isIpV6()) - return ::connect(sockfd, - addr.getSockAddr(), - static_cast( - sizeof(struct sockaddr_in6))); + return ::connect(sockfd, addr.getSockAddr(), static_cast(sizeof(struct sockaddr_in6))); else - return ::connect(sockfd, - addr.getSockAddr(), - static_cast( - sizeof(struct sockaddr_in))); + return ::connect(sockfd, addr.getSockAddr(), static_cast(sizeof(struct sockaddr_in))); } static bool isSelfConnect(int sockfd); @@ -101,7 +95,9 @@ public: explicit Socket(int sockfd) : sockFd_(sockfd) { } + ~Socket(); + /// abort if address in use void bindAddress(const InetAddress &localaddr); /// abort if address in use @@ -109,9 +105,11 @@ public: int accept(InetAddress *peeraddr); void closeWrite(); int read(char *buffer, uint64_t len); + int fd() { return sockFd_; } + static struct sockaddr_in6 getLocalAddr(int sockfd); static struct sockaddr_in6 getPeerAddr(int sockfd); diff --git a/core/net/tcp_client.cpp b/core/net/tcp_client.cpp index 2df6fc2..2d587b7 100644 --- a/core/net/tcp_client.cpp +++ b/core/net/tcp_client.cpp @@ -54,9 +54,7 @@ TcpClient::IgnoreSigPipe TcpClient::initObj; #endif static void defaultConnectionCallback(const TcpConnectionPtr &conn) { - LOG_TRACE << conn->localAddr().toIpPort() << " -> " - << conn->peerAddr().toIpPort() << " is " - << (conn->connected() ? "UP" : "DOWN"); + LOG_TRACE << conn->localAddr().toIpPort() << " -> " << conn->peerAddr().toIpPort() << " is " << (conn->connected() ? "UP" : "DOWN"); // do not call conn->forceClose(), because some users want to register // message callback only. } @@ -65,9 +63,7 @@ static void defaultMessageCallback(const TcpConnectionPtr &, MsgBuffer *buf) { buf->retrieveAll(); } -TcpClient::TcpClient(EventLoop *loop, - const InetAddress &serverAddr, - const std::string &nameArg) : +TcpClient::TcpClient(EventLoop *loop, const InetAddress &serverAddr, const std::string &nameArg) : loop_(loop), connector_(new Connector(loop, serverAddr, false)), name_(nameArg), @@ -75,23 +71,27 @@ TcpClient::TcpClient(EventLoop *loop, messageCallback_(defaultMessageCallback), retry_(false), connect_(true) { - connector_->setNewConnectionCallback( - std::bind(&TcpClient::newConnection, this, _1)); + + connector_->setNewConnectionCallback(std::bind(&TcpClient::newConnection, this, _1)); + connector_->setErrorCallback([this]() { if (connectionErrorCallback_) { connectionErrorCallback_(); } }); + LOG_TRACE << "TcpClient::TcpClient[" << name_ << "] - connector "; } TcpClient::~TcpClient() { LOG_TRACE << "TcpClient::~TcpClient[" << name_ << "] - connector "; + TcpConnectionImplPtr conn; { std::lock_guard lock(mutex_); conn = std::dynamic_pointer_cast(connection_); } + if (conn) { assert(loop_ == conn->getLoop()); // TODO: not 100% safe, if we are in different thread @@ -104,6 +104,7 @@ TcpClient::~TcpClient() { }); }); }); + conn->forceClose(); } else { /// TODO need test in this condition @@ -142,39 +143,34 @@ void TcpClient::newConnection(int sockfd) { // TODO poll with zero timeout to double confirm the new connection // TODO use make_shared if necessary std::shared_ptr conn; + if (sslCtxPtr_) { #ifdef USE_OPENSSL - conn = std::make_shared(loop_, - sockfd, - localAddr, - peerAddr, - sslCtxPtr_, - false, - validateCert_, - SSLHostName_); + conn = std::make_shared(loop_, sockfd, localAddr, peerAddr, sslCtxPtr_, false, validateCert_, SSLHostName_); #else LOG_FATAL << "OpenSSL is not found in your system!"; abort(); #endif } else { - conn = std::make_shared(loop_, - sockfd, - localAddr, - peerAddr); + conn = std::make_shared(loop_, sockfd, localAddr, peerAddr); } + conn->setConnectionCallback(connectionCallback_); conn->setRecvMsgCallback(messageCallback_); conn->setWriteCompleteCallback(writeCompleteCallback_); conn->setCloseCallback(std::bind(&TcpClient::removeConnection, this, _1)); + { std::lock_guard lock(mutex_); connection_ = conn; } + conn->setSSLErrorCallback([this](SSLError err) { if (sslErrorCallback_) { sslErrorCallback_(err); } }); + conn->connectEstablished(); } @@ -188,30 +184,23 @@ void TcpClient::removeConnection(const TcpConnectionPtr &conn) { connection_.reset(); } - loop_->queueInLoop( - std::bind(&TcpConnectionImpl::connectDestroyed, - std::dynamic_pointer_cast(conn))); + loop_->queueInLoop(std::bind(&TcpConnectionImpl::connectDestroyed, std::dynamic_pointer_cast(conn))); + if (retry_ && connect_) { - LOG_TRACE << "TcpClient::connect[" << name_ << "] - Reconnecting to " - << connector_->serverAddress().toIpPort(); + LOG_TRACE << "TcpClient::connect[" << name_ << "] - Reconnecting to " << connector_->serverAddress().toIpPort(); connector_->restart(); } } -void TcpClient::enableSSL( - bool useOldTLS, - bool validateCert, - std::string hostname, - const std::vector > &sslConfCmds) { +void TcpClient::enableSSL(bool useOldTLS, bool validateCert, std::string hostname, const std::vector > &sslConfCmds) { #ifdef USE_OPENSSL /* Create a new OpenSSL context */ sslCtxPtr_ = newSSLContext(useOldTLS, validateCert, sslConfCmds); + validateCert_ = validateCert; + if (!hostname.empty()) { - std::transform(hostname.begin(), - hostname.end(), - hostname.begin(), - tolower); + std::transform(hostname.begin(), hostname.end(), hostname.begin(), tolower); SSLHostName_ = std::move(hostname); } diff --git a/core/net/tcp_client.h b/core/net/tcp_client.h index 23dbd32..61b442b 100644 --- a/core/net/tcp_client.h +++ b/core/net/tcp_client.h @@ -42,10 +42,7 @@ class Connector; using ConnectorPtr = std::shared_ptr; class SSLContext; -/** - * @brief This class represents a TCP client. - * - */ + class TcpClient { protected: TcpClient(const TcpClient &) = delete; @@ -55,88 +52,35 @@ protected: TcpClient &operator=(TcpClient &&) noexcept(true) = default; public: - /** - * @brief Construct a new TCP client instance. - * - * @param loop The event loop in which the client runs. - * @param serverAddr The address of the server. - * @param nameArg The name of the client. - */ - TcpClient(EventLoop *loop, - const InetAddress &serverAddr, - const std::string &nameArg); + TcpClient(EventLoop *loop, const InetAddress &serverAddr, const std::string &nameArg); ~TcpClient(); - /** - * @brief Connect to the server. - * - */ void connect(); - - /** - * @brief Disconnect from the server. - * - */ void disconnect(); - /** - * @brief Stop connecting to the server. - * - */ void stop(); - /** - * @brief Get the TCP connection to the server. - * - * @return TcpConnectionPtr - */ TcpConnectionPtr connection() const { std::lock_guard lock(mutex_); return connection_; } - /** - * @brief Get the event loop. - * - * @return EventLoop* - */ EventLoop *getLoop() const { return loop_; } - /** - * @brief Check whether the client re-connect to the server. - * - * @return true - * @return false - */ bool retry() const { return retry_; } - /** - * @brief Enable retrying. - * - */ void enableRetry() { retry_ = true; } - /** - * @brief Get the name of the client. - * - * @return const std::string& - */ const std::string &name() const { return name_; } - /** - * @brief Set the connection callback. - * - * @param cb The callback is called when the connection to the server is - * established or closed. - */ void setConnectionCallback(const ConnectionCallback &cb) { connectionCallback_ = cb; } @@ -144,37 +88,18 @@ public: connectionCallback_ = std::move(cb); } - /** - * @brief Set the connection error callback. - * - * @param cb The callback is called when an error occurs during connecting - * to the server. - */ void setConnectionErrorCallback(const ConnectionErrorCallback &cb) { connectionErrorCallback_ = cb; } - /** - * @brief Set the message callback. - * - * @param cb The callback is called when some data is received from the - * server. - */ void setMessageCallback(const RecvMessageCallback &cb) { messageCallback_ = cb; } void setMessageCallback(RecvMessageCallback &&cb) { messageCallback_ = std::move(cb); } - /// Set write complete callback. - /// Not thread safe. - /** - * @brief Set the write complete callback. - * - * @param cb The callback is called when data to send is written to the - * socket. - */ + /// Not thread safe. void setWriteCompleteCallback(const WriteCompleteCallback &cb) { writeCompleteCallback_ = cb; } @@ -182,10 +107,6 @@ public: writeCompleteCallback_ = std::move(cb); } - /** - * @brief Set the callback for errors of SSL - * @param cb The callback is called when an SSL error occurs. - */ void setSSLErrorCallback(const SSLErrorCallback &cb) { sslErrorCallback_ = cb; } @@ -193,24 +114,7 @@ public: sslErrorCallback_ = std::move(cb); } - /** - * @brief Enable SSL encryption. - * @param useOldTLS If true, the TLS 1.0 and 1.1 are supported by the - * client. - * @param validateCert If true, we try to validate if the peer's SSL cert - * is valid. - * @param hostname The server hostname for SNI. If it is empty, the SNI is - * not used. - * @param sslConfCmds The commands used to call the SSL_CONF_cmd function in - * OpenSSL. - * @note It's well known that TLS 1.0 and 1.1 are not considered secure in - * 2020. And it's a good practice to only use TLS 1.2 and above. - */ - void enableSSL(bool useOldTLS = false, - bool validateCert = true, - std::string hostname = "", - const std::vector > - &sslConfCmds = {}); + void enableSSL(bool useOldTLS = false, bool validateCert = true, std::string hostname = "", const std::vector > &sslConfCmds = {}); private: /// Not thread safe, but in loop @@ -234,6 +138,7 @@ private: std::shared_ptr sslCtxPtr_; bool validateCert_{ false }; std::string SSLHostName_; + #ifndef _WIN32 class IgnoreSigPipe { public: diff --git a/core/net/tcp_connection.h b/core/net/tcp_connection.h index ee9e959..c5e5690 100644 --- a/core/net/tcp_connection.h +++ b/core/net/tcp_connection.h @@ -39,26 +39,13 @@ #include class SSLContext; -std::shared_ptr newSSLServerContext( - const std::string &certPath, - const std::string &keyPath, - bool useOldTLS = false, - const std::vector > &sslConfCmds = {}); -/** - * @brief This class represents a TCP connection. - * - */ +std::shared_ptr newSSLServerContext(const std::string &certPath, const std::string &keyPath, bool useOldTLS = false, const std::vector > &sslConfCmds = {}); + class TcpConnection { public: TcpConnection() = default; virtual ~TcpConnection(){}; - /** - * @brief Send some data to the peer. - * - * @param msg - * @param len - */ virtual void send(const char *msg, size_t len) = 0; virtual void send(const void *msg, size_t len) = 0; virtual void send(const std::string &msg) = 0; @@ -68,95 +55,27 @@ public: virtual void send(const std::shared_ptr &msgPtr) = 0; virtual void send(const std::shared_ptr &msgPtr) = 0; - /** - * @brief Send a file to the peer. - * - * @param fileName - * @param offset - * @param length - */ - virtual void sendFile(const char *fileName, - size_t offset = 0, - size_t length = 0) = 0; + virtual void sendFile(const char *fileName, size_t offset = 0, size_t length = 0) = 0; - /** - * @brief Get the local address of the connection. - * - * @return const InetAddress& - */ virtual const InetAddress &localAddr() const = 0; - - /** - * @brief Get the remote address of the connection. - * - * @return const InetAddress& - */ virtual const InetAddress &peerAddr() const = 0; - /** - * @brief Return true if the connection is established. - * - * @return true - * @return false - */ virtual bool connected() const = 0; - - /** - * @brief Return false if the connection is established. - * - * @return true - * @return false - */ virtual bool disconnected() const = 0; - /** - * @brief Get the buffer in which the received data stored. - * - * @return MsgBuffer* - */ virtual MsgBuffer *getRecvBuffer() = 0; - /** - * @brief Set the high water mark callback - * - * @param cb The callback is called when the data in sending buffer is - * larger than the water mark. - * @param markLen The water mark in bytes. - */ - virtual void setHighWaterMarkCallback(const HighWaterMarkCallback &cb, - size_t markLen) = 0; - /** - * @brief Set the TCP_NODELAY option to the socket. - * - * @param on - */ + virtual void setHighWaterMarkCallback(const HighWaterMarkCallback &cb, size_t markLen) = 0; + virtual void setTcpNoDelay(bool on) = 0; - /** - * @brief Shutdown the connection. - * @note This method only closes the writing direction. - */ virtual void shutdown() = 0; - - /** - * @brief Close the connection forcefully. - * - */ virtual void forceClose() = 0; - /** - * @brief Get the event loop in which the connection I/O is handled. - * - * @return EventLoop* - */ + virtual EventLoop *getLoop() = 0; - /** - * @brief Set the custom data on the connection. - * - * @param context - */ void setContext(const std::shared_ptr &context) { contextPtr_ = context; } @@ -164,98 +83,30 @@ public: contextPtr_ = std::move(context); } - /** - * @brief Get the custom data from the connection. - * - * @tparam T - * @return std::shared_ptr - */ template std::shared_ptr getContext() const { return std::static_pointer_cast(contextPtr_); } - /** - * @brief Return true if the custom data is set by user. - * - * @return true - * @return false - */ bool hasContext() const { return (bool)contextPtr_; } - /** - * @brief Clear the custom data. - * - */ void clearContext() { contextPtr_.reset(); } - /** - * @brief Call this method to avoid being kicked off by TcpServer, refer to - * the kickoffIdleConnections method in the TcpServer class. - * - */ virtual void keepAlive() = 0; - - /** - * @brief Return true if the keepAlive() method is called. - * - * @return true - * @return false - */ virtual bool isKeepAlive() = 0; - /** - * @brief Return the number of bytes sent - * - * @return size_t - */ virtual size_t bytesSent() const = 0; - - /** - * @brief Return the number of bytes received. - * - * @return size_t - */ virtual size_t bytesReceived() const = 0; - /** - * @brief Check whether the connection is SSL encrypted. - * - * @return true - * @return false - */ virtual bool isSSLConnection() const = 0; - /** - * @brief Start the SSL encryption on the connection (as a client). - * - * @param callback The callback is called when the SSL connection is - * established. - * @param hostname The server hostname for SNI. If it is empty, the SNI is - * not used. - * @param sslConfCmds The commands used to call the SSL_CONF_cmd function in - * OpenSSL. - */ - virtual void startClientEncryption( - std::function callback, - bool useOldTLS = false, - bool validateCert = true, - std::string hostname = "", - const std::vector > &sslConfCmds = {}) = 0; + virtual void startClientEncryption(std::function callback, bool useOldTLS = false, bool validateCert = true, std::string hostname = "", const std::vector > &sslConfCmds = {}) = 0; - /** - * @brief Start the SSL encryption on the connection (as a server). - * - * @param ctx The SSL context. - * @param callback The callback is called when the SSL connection is - * established. - */ - virtual void startServerEncryption(const std::shared_ptr &ctx, - std::function callback) = 0; + virtual void startServerEncryption(const std::shared_ptr &ctx, std::function callback) = 0; protected: bool validateCert_ = false; diff --git a/core/net/tcp_server.cpp b/core/net/tcp_server.cpp index f03d347..a0e13a7 100644 --- a/core/net/tcp_server.cpp +++ b/core/net/tcp_server.cpp @@ -28,20 +28,16 @@ // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -#include "core/loops/acceptor.h" -#include "core/net/connections/tcp_connection_impl.h" #include "core/net/tcp_server.h" #include "core/log/logger.h" +#include "core/loops/acceptor.h" +#include "core/net/connections/tcp_connection_impl.h" #include #include using namespace std::placeholders; -TcpServer::TcpServer(EventLoop *loop, - const InetAddress &address, - const std::string &name, - bool reUseAddr, - bool reUsePort) : +TcpServer::TcpServer(EventLoop *loop, const InetAddress &address, const std::string &name, bool reUseAddr, bool reUsePort) : loop_(loop), acceptorPtr_(new Acceptor(loop, address, reUseAddr, reUsePort)), serverName_(name), @@ -50,8 +46,8 @@ TcpServer::TcpServer(EventLoop *loop, << " bytes]"; buffer->retrieveAll(); }) { - acceptorPtr_->setNewConnectionCallback( - std::bind(&TcpServer::newConnection, this, _1, _2)); + + acceptorPtr_->setNewConnectionCallback(std::bind(&TcpServer::newConnection, this, _1, _2)); } TcpServer::~TcpServer() { @@ -69,14 +65,20 @@ void TcpServer::newConnection(int sockfd, const InetAddress &peer) { // LOG_TRACE<<"vector size:"<assertInLoopThread(); + EventLoop *ioLoop = NULL; if (loopPoolPtr_ && loopPoolPtr_->size() > 0) { ioLoop = loopPoolPtr_->getNextLoop(); } - if (ioLoop == NULL) + + if (ioLoop == NULL) { ioLoop = loop_; + } + std::shared_ptr newPtr; + if (sslCtxPtr_) { #ifdef USE_OPENSSL newPtr = std::make_shared( @@ -90,14 +92,14 @@ void TcpServer::newConnection(int sockfd, const InetAddress &peer) { abort(); #endif } else { - newPtr = std::make_shared( - ioLoop, sockfd, InetAddress(Socket::getLocalAddr(sockfd)), peer); + newPtr = std::make_shared(ioLoop, sockfd, InetAddress(Socket::getLocalAddr(sockfd)), peer); } if (idleTimeout_ > 0) { assert(timingWheelMap_[ioLoop]); newPtr->enableKickingOff(idleTimeout_, timingWheelMap_[ioLoop]); } + newPtr->setRecvMsgCallback(recvMessageCallback_); newPtr->setConnectionCallback( @@ -105,11 +107,13 @@ void TcpServer::newConnection(int sockfd, const InetAddress &peer) { if (connectionCallback_) connectionCallback_(connectionPtr); }); + newPtr->setWriteCompleteCallback( [this](const TcpConnectionPtr &connectionPtr) { if (writeCompleteCallback_) writeCompleteCallback_(connectionPtr); }); + newPtr->setCloseCallback(std::bind(&TcpServer::connectionClosed, this, _1)); connSet_.insert(newPtr); newPtr->connectEstablished(); @@ -143,6 +147,7 @@ void TcpServer::start() { acceptorPtr_->listen(); }); } + void TcpServer::stop() { loop_->runInLoop([this]() { acceptorPtr_.reset(); }); for (auto connection : connSet_) { @@ -159,6 +164,7 @@ void TcpServer::stop() { f.get(); } } + void TcpServer::connectionClosed(const TcpConnectionPtr &connectionPtr) { LOG_TRACE << "connectionClosed"; // loop_->assertInLoopThread(); @@ -179,11 +185,7 @@ const InetAddress &TcpServer::address() const { return acceptorPtr_->addr(); } -void TcpServer::enableSSL( - const std::string &certPath, - const std::string &keyPath, - bool useOldTLS, - const std::vector > &sslConfCmds) { +void TcpServer::enableSSL(const std::string &certPath, const std::string &keyPath, bool useOldTLS, const std::vector > &sslConfCmds) { #ifdef USE_OPENSSL /* Create a new OpenSSL context */ sslCtxPtr_ = newSSLServerContext(certPath, keyPath, useOldTLS, sslConfCmds); diff --git a/core/net/tcp_server.h b/core/net/tcp_server.h index 5cf27c7..6fbb7b4 100644 --- a/core/net/tcp_server.h +++ b/core/net/tcp_server.h @@ -43,10 +43,7 @@ class Acceptor; class SSLContext; -/** - * @brief This class represents a TCP server. - * - */ + class TcpServer { protected: TcpServer(const TcpServer &) = delete; @@ -56,53 +53,18 @@ protected: TcpServer &operator=(TcpServer &&) noexcept(true) = default; public: - /** - * @brief Construct a new TCP server instance. - * - * @param loop The event loop in which the acceptor of the server is - * handled. - * @param address The address of the server. - * @param name The name of the server. - * @param reUseAddr The SO_REUSEADDR option. - * @param reUsePort The SO_REUSEPORT option. - */ - TcpServer(EventLoop *loop, - const InetAddress &address, - const std::string &name, - bool reUseAddr = true, - bool reUsePort = true); + TcpServer(EventLoop *loop, const InetAddress &address, const std::string &name, bool reUseAddr = true, bool reUsePort = true); ~TcpServer(); - /** - * @brief Start the server. - * - */ void start(); - - /** - * @brief Stop the server. - * - */ void stop(); - /** - * @brief Set the number of event loops in which the I/O of connections to - * the server is handled. - * - * @param num - */ void setIoLoopNum(size_t num) { assert(!started_); loopPoolPtr_ = std::make_shared(num); loopPoolPtr_->start(); } - /** - * @brief Set the event loops pool in which the I/O of connections to - * the server is handled. - * - * @param pool - */ void setIoLoopThreadPool(const std::shared_ptr &pool) { assert(pool->size() > 0); assert(!started_); @@ -110,12 +72,6 @@ public: loopPoolPtr_->start(); } - /** - * @brief Set the message callback. - * - * @param cb The callback is called when some data is received on a - * connection to the server. - */ void setRecvMessageCallback(const RecvMessageCallback &cb) { recvMessageCallback_ = cb; } @@ -123,12 +79,6 @@ public: recvMessageCallback_ = std::move(cb); } - /** - * @brief Set the connection callback. - * - * @param cb The callback is called when a connection is established or - * closed. - */ void setConnectionCallback(const ConnectionCallback &cb) { connectionCallback_ = cb; } @@ -136,12 +86,6 @@ public: connectionCallback_ = std::move(cb); } - /** - * @brief Set the write complete callback. - * - * @param cb The callback is called when data to send is written to the - * socket of a connection. - */ void setWriteCompleteCallback(const WriteCompleteCallback &cb) { writeCompleteCallback_ = cb; } @@ -149,53 +93,21 @@ public: writeCompleteCallback_ = std::move(cb); } - /** - * @brief Get the name of the server. - * - * @return const std::string& - */ const std::string &name() const { return serverName_; } - /** - * @brief Get the IP and port string of the server. - * - * @return const std::string - */ const std::string ipPort() const; - - /** - * @brief Get the address of the server. - * - * @return const InetAddress& - */ const InetAddress &address() const; - /** - * @brief Get the event loop of the server. - * - * @return EventLoop* - */ EventLoop *getLoop() const { return loop_; } - /** - * @brief Get the I/O event loops of the server. - * - * @return std::vector - */ std::vector getIoLoops() const { return loopPoolPtr_->getLoops(); } - /** - * @brief An idle connection is a connection that has no read or write, kick - * off it after timeout seconds. - * - * @param timeout - */ void kickoffIdleConnections(size_t timeout) { loop_->runInLoop([this, timeout]() { assert(!started_); @@ -203,23 +115,12 @@ public: }); } - /** - * @brief Enable SSL encryption. - * - * @param certPath The path of the certificate file. - * @param keyPath The path of the private key file. - * @param useOldTLS If true, the TLS 1.0 and 1.1 are supported by the - * server. - * @param sslConfCmds The commands used to call the SSL_CONF_cmd function in - * OpenSSL. - * @note It's well known that TLS 1.0 and 1.1 are not considered secure in - * 2020. And it's a good practice to only use TLS 1.2 and above. - */ - void enableSSL(const std::string &certPath, - const std::string &keyPath, - bool useOldTLS = false, - const std::vector > - &sslConfCmds = {}); + // certPath The path of the certificate file. + // keyPath The path of the private key file. + // useOldTLS If true, the TLS 1.0 and 1.1 are supported by the server. + // sslConfCmds The commands used to call the SSL_CONF_cmd function in OpenSSL. + // Note: It's well known that TLS 1.0 and 1.1 are not considered secure in 2020. And it's a good practice to only use TLS 1.2 and above. + void enableSSL(const std::string &certPath, const std::string &keyPath, bool useOldTLS = false, const std::vector > &sslConfCmds = {}); private: EventLoop *loop_; @@ -236,6 +137,7 @@ private: std::map > timingWheelMap_; void connectionClosed(const TcpConnectionPtr &connectionPtr); std::shared_ptr loopPoolPtr_; + #ifndef _WIN32 class IgnoreSigPipe { public: @@ -247,6 +149,7 @@ private: IgnoreSigPipe initObj; #endif + bool started_{ false }; // OpenSSL SSL context Object;