rcpp_framework/modules/drogon/trantor/net/TcpClient.cc

207 lines
5.1 KiB
C++

// Copyright 2010, Shuo Chen. All rights reserved.
// http://code.google.com/p/muduo/
//
// Use of this source code is governed by a BSD-style license
// that can be found in the License file.
// Author: Shuo Chen (chenshuo at chenshuo dot com)
//
// Taken from muduo and modified by an tao
#include <trantor/net/TcpClient.h>
#include "Connector.h"
#include "inner/TcpConnectionImpl.h"
#include <trantor/net/EventLoop.h>
#include <trantor/utils/Logger.h>
#include <algorithm>
#include <cctype>
#include <functional>
#include "Socket.h"
#include <stdio.h> // snprintf
using namespace trantor;
using namespace std::placeholders;
namespace trantor {
// void removeConnector(const ConnectorPtr &)
// {
// // connector->
// }
#ifndef _WIN32
TcpClient::IgnoreSigPipe TcpClient::initObj;
#endif
static void defaultConnectionCallback(const TcpConnectionPtr &conn) {
LOG_TRACE << conn->localAddr().toIpPort() << " -> "
<< conn->peerAddr().toIpPort() << " is "
<< (conn->connected() ? "UP" : "DOWN");
// do not call conn->forceClose(), because some users want to register
// message callback only.
}
static void defaultMessageCallback(const TcpConnectionPtr &, MsgBuffer *buf) {
buf->retrieveAll();
}
} // namespace trantor
/// Builds a client that will connect to `serverAddr` using `loop`.
///
/// Nothing is started here; connect() starts the connector. The default
/// connection/message callbacks are installed and automatic retry is off.
///
/// @param loop        event loop driving the connector and connections
/// @param serverAddr  address of the remote server
/// @param nameArg     name used in trace output
TcpClient::TcpClient(EventLoop *loop,
                     const InetAddress &serverAddr,
                     const std::string &nameArg)
    : loop_(loop),
      connector_(new Connector(loop, serverAddr, false)),
      name_(nameArg),
      connectionCallback_(defaultConnectionCallback),
      messageCallback_(defaultMessageCallback),
      retry_(false),
      connect_(true) {
    // Each socket the connector establishes is handed to newConnection().
    connector_->setNewConnectionCallback(
        [this](int sockfd) { newConnection(sockfd); });

    // Connection failures are forwarded to the user's error callback, if set.
    connector_->setErrorCallback([this]() {
        if (!connectionErrorCallback_)
            return;
        connectionErrorCallback_();
    });

    LOG_TRACE << "TcpClient::TcpClient[" << name_ << "] - connector ";
}
/// Destroys the client.
///
/// If a connection is currently held, its close callback is replaced. The
/// callback installed by newConnection() is bound to this TcpClient; the
/// replacement does not reference `this`. The connection is then
/// force-closed. If no connection is held, the connector is stopped instead.
TcpClient::~TcpClient() {
    LOG_TRACE << "TcpClient::~TcpClient[" << name_ << "] - connector ";
    TcpConnectionImplPtr conn;
    {
        // Snapshot connection_ under the lock; newConnection() and
        // removeConnection() modify it under the same mutex.
        std::lock_guard<std::mutex> lock(mutex_);
        conn = std::dynamic_pointer_cast<TcpConnectionImpl>(connection_);
    }
    if (conn) {
        assert(loop_ == conn->getLoop());
        // TODO: not 100% safe, if we are in different thread
        // Copy loop_ into a local so the lambdas below capture the EventLoop
        // pointer rather than `this`, which is being destroyed.
        auto loop = loop_;
        // Ask the loop to replace the close callback. The replacement only
        // defers connectDestroyed() through the loop, so closing the
        // connection no longer calls back into this TcpClient.
        // NOTE(review): if this destructor runs off the loop thread, the
        // connection may close and invoke the old callback (which calls
        // removeConnection() on `this`) before this task runs. The SSL error
        // callback set in newConnection() also captures `this` and is not
        // replaced here -- confirm it cannot fire after destruction.
        loop_->runInLoop([conn, loop]() {
            conn->setCloseCallback([loop](const TcpConnectionPtr &connPtr) {
                loop->queueInLoop([connPtr]() {
                    static_cast<TcpConnectionImpl *>(connPtr.get())
                        ->connectDestroyed();
                });
            });
        });
        conn->forceClose();
    } else {
        /// TODO need test in this condition
        connector_->stop();
    }
}
/// Starts connecting to the configured server address.
///
/// Sets connect_ so that, with retry enabled, removeConnection() will
/// restart the connector after the connection drops.
void TcpClient::connect() {
    // TODO: check state
    LOG_TRACE << "TcpClient::connect[" << name_
              << "] - connecting to "
              << connector_->serverAddress().toIpPort();
    connect_ = true;
    connector_->start();
}
/// Gracefully ends the current connection, if any.
///
/// Clears connect_ first so removeConnection() will not reconnect, then
/// calls shutdown() on the held connection while holding mutex_.
void TcpClient::disconnect() {
    connect_ = false;
    std::lock_guard<std::mutex> guard(mutex_);
    if (connection_)
        connection_->shutdown();
}
/// Stops the connector and disables reconnection.
///
/// Only connect_ and the connector are touched; an already established
/// connection is left alone (use disconnect() for that).
void TcpClient::stop() {
    connect_ = false;   // prevents removeConnection() from restarting
    connector_->stop();
}
void TcpClient::newConnection(int sockfd) {
loop_->assertInLoopThread();
InetAddress peerAddr(Socket::getPeerAddr(sockfd));
InetAddress localAddr(Socket::getLocalAddr(sockfd));
// TODO poll with zero timeout to double confirm the new connection
// TODO use make_shared if necessary
std::shared_ptr<TcpConnectionImpl> conn;
if (sslCtxPtr_) {
#ifdef USE_OPENSSL
conn = std::make_shared<TcpConnectionImpl>(loop_,
sockfd,
localAddr,
peerAddr,
sslCtxPtr_,
false,
validateCert_,
SSLHostName_);
#else
LOG_FATAL << "OpenSSL is not found in your system!";
abort();
#endif
} else {
conn = std::make_shared<TcpConnectionImpl>(loop_,
sockfd,
localAddr,
peerAddr);
}
conn->setConnectionCallback(connectionCallback_);
conn->setRecvMsgCallback(messageCallback_);
conn->setWriteCompleteCallback(writeCompleteCallback_);
conn->setCloseCallback(std::bind(&TcpClient::removeConnection, this, _1));
{
std::lock_guard<std::mutex> lock(mutex_);
connection_ = conn;
}
conn->setSSLErrorCallback([this](SSLError err) {
if (sslErrorCallback_) {
sslErrorCallback_(err);
}
});
conn->connectEstablished();
}
void TcpClient::removeConnection(const TcpConnectionPtr &conn) {
loop_->assertInLoopThread();
assert(loop_ == conn->getLoop());
{
std::lock_guard<std::mutex> lock(mutex_);
assert(connection_ == conn);
connection_.reset();
}
loop_->queueInLoop(
std::bind(&TcpConnectionImpl::connectDestroyed,
std::dynamic_pointer_cast<TcpConnectionImpl>(conn)));
if (retry_ && connect_) {
LOG_TRACE << "TcpClient::connect[" << name_ << "] - Reconnecting to "
<< connector_->serverAddress().toIpPort();
connector_->restart();
}
}
/// Configures TLS for connections created after this call.
///
/// Builds a new SSL context and records whether the peer certificate should
/// be validated. A non-empty `hostname` is lower-cased and stored for use
/// when connections are created. Without OpenSSL support this logs a fatal
/// error and aborts.
///
/// @param useOldTLS     forwarded to newSSLContext()
/// @param validateCert  whether to validate the server certificate
/// @param hostname      expected server host name; ignored when empty
/// @param sslConfCmds   extra configuration commands for newSSLContext()
void TcpClient::enableSSL(
    bool useOldTLS,
    bool validateCert,
    std::string hostname,
    const std::vector<std::pair<std::string, std::string> > &sslConfCmds) {
#ifdef USE_OPENSSL
    /* Create a new OpenSSL context */
    sslCtxPtr_ = newSSLContext(useOldTLS, validateCert, sslConfCmds);
    validateCert_ = validateCert;
    if (!hostname.empty()) {
        // Each byte goes through unsigned char first: std::tolower has
        // undefined behaviour for negative arguments, which a plain char
        // holding a non-ASCII byte can be where char is signed.
        std::transform(hostname.begin(),
                       hostname.end(),
                       hostname.begin(),
                       [](unsigned char ch) {
                           return static_cast<char>(std::tolower(ch));
                       });
        SSLHostName_ = std::move(hostname);
    }
#else
    LOG_FATAL << "OpenSSL is not found in your system!";
    abort();
#endif
}