diff options
Diffstat (limited to 'adb/pairing_connection/pairing_server.cpp')
-rw-r--r-- | adb/pairing_connection/pairing_server.cpp | 466 |
1 files changed, 466 insertions, 0 deletions
diff --git a/adb/pairing_connection/pairing_server.cpp b/adb/pairing_connection/pairing_server.cpp new file mode 100644 index 000000000..7218eacf2 --- /dev/null +++ b/adb/pairing_connection/pairing_server.cpp @@ -0,0 +1,466 @@ +/* + * Copyright (C) 2020 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "adb/pairing/pairing_server.h" + +#include <sys/epoll.h> +#include <sys/eventfd.h> + +#include <atomic> +#include <deque> +#include <iomanip> +#include <mutex> +#include <sstream> +#include <thread> +#include <tuple> +#include <unordered_map> +#include <variant> +#include <vector> + +#include <adb/crypto/rsa_2048_key.h> +#include <adb/crypto/x509_generator.h> +#include <adb/pairing/pairing_connection.h> +#include <android-base/logging.h> +#include <android-base/parsenetaddress.h> +#include <android-base/thread_annotations.h> +#include <android-base/unique_fd.h> +#include <cutils/sockets.h> + +#include "internal/constants.h" + +using android::base::ScopedLockAssertion; +using android::base::unique_fd; +using namespace adb::crypto; +using namespace adb::pairing; + +// The implementation has two background threads running: one to handle and +// accept any new pairing connection requests (socket accept), and the other to +// handle connection events (connection started, connection finished). +struct PairingServerCtx { + public: + using Data = std::vector<uint8_t>; + + virtual ~PairingServerCtx(); + + // All parameters must be non-empty. + explicit PairingServerCtx(const Data& pswd, const PeerInfo& peer_info, const Data& cert, + const Data& priv_key, uint16_t port); + + // Starts the pairing server. This call is non-blocking. Upon completion, + // if the pairing was successful, then |cb| will be called with the PublicKeyHeader + // containing the info of the trusted peer. Otherwise, |cb| will be + // called with an empty value. Start can only be called once in the lifetime + // of this object. + // + // Returns the port number if PairingServerCtx was successfully started. Otherwise, + // returns 0. + uint16_t Start(pairing_server_result_cb cb, void* opaque); + + private: + // Setup the server socket to accept incoming connections. Returns the + // server port number (> 0 on success). + uint16_t SetupServer(); + // Force stop the server thread. + void StopServer(); + + // handles a new pairing client connection + bool HandleNewClientConnection(int fd) EXCLUDES(conn_mutex_); + + // ======== connection events thread ============= + std::mutex conn_mutex_; + std::condition_variable conn_cv_; + + using FdVal = int; + struct ConnectionDeleter { + void operator()(PairingConnectionCtx* p) { pairing_connection_destroy(p); } + }; + using ConnectionPtr = std::unique_ptr<PairingConnectionCtx, ConnectionDeleter>; + static ConnectionPtr CreatePairingConnection(const Data& pswd, const PeerInfo& peer_info, + const Data& cert, const Data& priv_key); + using NewConnectionEvent = std::tuple<unique_fd, ConnectionPtr>; + // <fd, PeerInfo.type, PeerInfo.data> + using ConnectionFinishedEvent = std::tuple<FdVal, uint8_t, std::optional<std::string>>; + using ConnectionEvent = std::variant<NewConnectionEvent, ConnectionFinishedEvent>; + // Queue for connections to write into. We have a separate queue to read + // from, in order to minimize the time the server thread is blocked. + std::deque<ConnectionEvent> conn_write_queue_ GUARDED_BY(conn_mutex_); + std::deque<ConnectionEvent> conn_read_queue_; + // Map of fds to their PairingConnections currently running. + std::unordered_map<FdVal, ConnectionPtr> connections_; + + // Two threads launched when starting the pairing server: + // 1) A server thread that waits for incoming client connections, and + // 2) A connection events thread that synchonizes events from all of the + // clients, since each PairingConnection is running in it's own thread. + void StartConnectionEventsThread(); + void StartServerThread(); + + static void PairingConnectionCallback(const PeerInfo* peer_info, int fd, void* opaque); + + std::thread conn_events_thread_; + void ConnectionEventsWorker(); + std::thread server_thread_; + void ServerWorker(); + bool is_terminate_ GUARDED_BY(conn_mutex_) = false; + + enum class State { + Ready, + Running, + Stopped, + }; + State state_ = State::Ready; + Data pswd_; + PeerInfo peer_info_; + Data cert_; + Data priv_key_; + uint16_t port_; + + pairing_server_result_cb cb_; + void* opaque_ = nullptr; + bool got_valid_pairing_ = false; + + static const int kEpollConstSocket = 0; + // Used to break the server thread from epoll_wait + static const int kEpollConstEventFd = 1; + unique_fd epoll_fd_; + unique_fd server_fd_; + unique_fd event_fd_; +}; // PairingServerCtx + +// static +PairingServerCtx::ConnectionPtr PairingServerCtx::CreatePairingConnection(const Data& pswd, + const PeerInfo& peer_info, + const Data& cert, + const Data& priv_key) { + return ConnectionPtr(pairing_connection_server_new(pswd.data(), pswd.size(), &peer_info, + cert.data(), cert.size(), priv_key.data(), + priv_key.size())); +} + +PairingServerCtx::PairingServerCtx(const Data& pswd, const PeerInfo& peer_info, const Data& cert, + const Data& priv_key, uint16_t port) + : pswd_(pswd), peer_info_(peer_info), cert_(cert), priv_key_(priv_key), port_(port) { + CHECK(!pswd_.empty() && !cert_.empty() && !priv_key_.empty()); +} + +PairingServerCtx::~PairingServerCtx() { + // Since these connections have references to us, let's make sure they + // destruct before us. + if (server_thread_.joinable()) { + StopServer(); + server_thread_.join(); + } + + { + std::lock_guard<std::mutex> lock(conn_mutex_); + is_terminate_ = true; + } + conn_cv_.notify_one(); + if (conn_events_thread_.joinable()) { + conn_events_thread_.join(); + } + + // Notify the cb_ if it hasn't already. + if (!got_valid_pairing_ && cb_ != nullptr) { + cb_(nullptr, opaque_); + } +} + +uint16_t PairingServerCtx::Start(pairing_server_result_cb cb, void* opaque) { + cb_ = cb; + opaque_ = opaque; + + if (state_ != State::Ready) { + LOG(ERROR) << "PairingServerCtx already running or stopped"; + return 0; + } + + port_ = SetupServer(); + if (port_ == 0) { + LOG(ERROR) << "Unable to start PairingServer"; + state_ = State::Stopped; + return 0; + } + LOG(INFO) << "Pairing server started on port " << port_; + + state_ = State::Running; + return port_; +} + +void PairingServerCtx::StopServer() { + if (event_fd_.get() == -1) { + return; + } + uint64_t value = 1; + ssize_t rc = write(event_fd_.get(), &value, sizeof(value)); + if (rc == -1) { + // This can happen if the server didn't start. + PLOG(ERROR) << "write to eventfd failed"; + } else if (rc != sizeof(value)) { + LOG(FATAL) << "write to event returned short (" << rc << ")"; + } +} + +uint16_t PairingServerCtx::SetupServer() { + epoll_fd_.reset(epoll_create1(EPOLL_CLOEXEC)); + if (epoll_fd_ == -1) { + PLOG(ERROR) << "failed to create epoll fd"; + return 0; + } + + event_fd_.reset(eventfd(0, EFD_CLOEXEC | EFD_NONBLOCK)); + if (event_fd_ == -1) { + PLOG(ERROR) << "failed to create eventfd"; + return 0; + } + + server_fd_.reset(socket_inaddr_any_server(port_, SOCK_STREAM)); + if (server_fd_.get() == -1) { + PLOG(ERROR) << "Failed to start pairing connection server"; + return 0; + } else if (fcntl(server_fd_.get(), F_SETFD, FD_CLOEXEC) != 0) { + PLOG(ERROR) << "Failed to make server socket cloexec"; + return 0; + } else if (fcntl(server_fd_.get(), F_SETFD, O_NONBLOCK) != 0) { + PLOG(ERROR) << "Failed to make server socket nonblocking"; + return 0; + } + + StartConnectionEventsThread(); + StartServerThread(); + int port = socket_get_local_port(server_fd_.get()); + return (port <= 0 ? 0 : port); +} + +void PairingServerCtx::StartServerThread() { + server_thread_ = std::thread([this]() { ServerWorker(); }); +} + +void PairingServerCtx::StartConnectionEventsThread() { + conn_events_thread_ = std::thread([this]() { ConnectionEventsWorker(); }); +} + +void PairingServerCtx::ServerWorker() { + { + struct epoll_event event; + event.events = EPOLLIN; + event.data.u64 = kEpollConstSocket; + CHECK_EQ(0, epoll_ctl(epoll_fd_.get(), EPOLL_CTL_ADD, server_fd_.get(), &event)); + } + + { + struct epoll_event event; + event.events = EPOLLIN; + event.data.u64 = kEpollConstEventFd; + CHECK_EQ(0, epoll_ctl(epoll_fd_.get(), EPOLL_CTL_ADD, event_fd_.get(), &event)); + } + + while (true) { + struct epoll_event events[2]; + int rc = TEMP_FAILURE_RETRY(epoll_wait(epoll_fd_.get(), events, 2, -1)); + if (rc == -1) { + PLOG(ERROR) << "epoll_wait failed"; + return; + } else if (rc == 0) { + LOG(ERROR) << "epoll_wait returned 0"; + return; + } + + for (int i = 0; i < rc; ++i) { + struct epoll_event& event = events[i]; + switch (event.data.u64) { + case kEpollConstSocket: + HandleNewClientConnection(server_fd_.get()); + break; + case kEpollConstEventFd: + uint64_t dummy; + int rc = TEMP_FAILURE_RETRY(read(event_fd_.get(), &dummy, sizeof(dummy))); + if (rc != sizeof(dummy)) { + PLOG(FATAL) << "failed to read from eventfd (rc=" << rc << ")"; + } + return; + } + } + } +} + +// static +void PairingServerCtx::PairingConnectionCallback(const PeerInfo* peer_info, int fd, void* opaque) { + auto* p = reinterpret_cast<PairingServerCtx*>(opaque); + + ConnectionFinishedEvent event; + if (peer_info != nullptr) { + if (peer_info->type == ADB_RSA_PUB_KEY) { + event = std::make_tuple(fd, peer_info->type, + std::string(reinterpret_cast<const char*>(peer_info->data))); + } else { + LOG(WARNING) << "Ignoring successful pairing because of unknown " + << "PeerInfo type=" << peer_info->type; + } + } else { + event = std::make_tuple(fd, 0, std::nullopt); + } + { + std::lock_guard<std::mutex> lock(p->conn_mutex_); + p->conn_write_queue_.push_back(std::move(event)); + } + p->conn_cv_.notify_one(); +} + +void PairingServerCtx::ConnectionEventsWorker() { + uint8_t num_tries = 0; + for (;;) { + // Transfer the write queue to the read queue. + { + std::unique_lock<std::mutex> lock(conn_mutex_); + ScopedLockAssertion assume_locked(conn_mutex_); + + if (is_terminate_) { + // We check |is_terminate_| twice because condition_variable's + // notify() only wakes up a thread if it is in the wait state + // prior to notify(). Furthermore, we aren't holding the mutex + // when processing the events in |conn_read_queue_|. + return; + } + if (conn_write_queue_.empty()) { + // We need to wait for new events, or the termination signal. + conn_cv_.wait(lock, [this]() REQUIRES(conn_mutex_) { + return (is_terminate_ || !conn_write_queue_.empty()); + }); + } + if (is_terminate_) { + // We're done. + return; + } + // Move all events into the read queue. + conn_read_queue_ = std::move(conn_write_queue_); + conn_write_queue_.clear(); + } + + // Process all events in the read queue. + while (conn_read_queue_.size() > 0) { + auto& event = conn_read_queue_.front(); + if (auto* p = std::get_if<NewConnectionEvent>(&event)) { + // Ignore if we are already at the max number of connections + if (connections_.size() >= internal::kMaxConnections) { + conn_read_queue_.pop_front(); + continue; + } + auto [ufd, connection] = std::move(*p); + int fd = ufd.release(); + bool started = pairing_connection_start(connection.get(), fd, + PairingConnectionCallback, this); + if (!started) { + LOG(ERROR) << "PairingServer unable to start a PairingConnection fd=" << fd; + ufd.reset(fd); + } else { + connections_[fd] = std::move(connection); + } + } else if (auto* p = std::get_if<ConnectionFinishedEvent>(&event)) { + auto [fd, info_type, public_key] = std::move(*p); + if (public_key.has_value() && !public_key->empty()) { + // Valid pairing. Let's shutdown the server and close any + // pairing connections in progress. + StopServer(); + connections_.clear(); + + PeerInfo info = {}; + info.type = info_type; + strncpy(reinterpret_cast<char*>(info.data), public_key->data(), + public_key->size()); + + cb_(&info, opaque_); + + got_valid_pairing_ = true; + return; + } + // Invalid pairing. Close the invalid connection. + if (connections_.find(fd) != connections_.end()) { + connections_.erase(fd); + } + + if (++num_tries >= internal::kMaxPairingAttempts) { + cb_(nullptr, opaque_); + // To prevent the destructor from calling it again. + cb_ = nullptr; + return; + } + } + conn_read_queue_.pop_front(); + } + } +} + +bool PairingServerCtx::HandleNewClientConnection(int fd) { + unique_fd ufd(TEMP_FAILURE_RETRY(accept4(fd, nullptr, nullptr, SOCK_CLOEXEC))); + if (ufd == -1) { + PLOG(WARNING) << "adb_socket_accept failed fd=" << fd; + return false; + } + auto connection = CreatePairingConnection(pswd_, peer_info_, cert_, priv_key_); + if (connection == nullptr) { + LOG(ERROR) << "PairingServer unable to create a PairingConnection fd=" << fd; + return false; + } + // send the new connection to the connection thread for further processing + NewConnectionEvent event = std::make_tuple(std::move(ufd), std::move(connection)); + { + std::lock_guard<std::mutex> lock(conn_mutex_); + conn_write_queue_.push_back(std::move(event)); + } + conn_cv_.notify_one(); + + return true; +} + +uint16_t pairing_server_start(PairingServerCtx* ctx, pairing_server_result_cb cb, void* opaque) { + return ctx->Start(cb, opaque); +} + +PairingServerCtx* pairing_server_new(const uint8_t* pswd, size_t pswd_len, + const PeerInfo* peer_info, const uint8_t* x509_cert_pem, + size_t x509_size, const uint8_t* priv_key_pem, + size_t priv_size, uint16_t port) { + CHECK(pswd); + CHECK_GT(pswd_len, 0U); + CHECK(x509_cert_pem); + CHECK_GT(x509_size, 0U); + CHECK(priv_key_pem); + CHECK_GT(priv_size, 0U); + CHECK(peer_info); + std::vector<uint8_t> vec_pswd(pswd, pswd + pswd_len); + std::vector<uint8_t> vec_x509_cert(x509_cert_pem, x509_cert_pem + x509_size); + std::vector<uint8_t> vec_priv_key(priv_key_pem, priv_key_pem + priv_size); + return new PairingServerCtx(vec_pswd, *peer_info, vec_x509_cert, vec_priv_key, port); +} + +PairingServerCtx* pairing_server_new_no_cert(const uint8_t* pswd, size_t pswd_len, + const PeerInfo* peer_info, uint16_t port) { + auto rsa_2048 = CreateRSA2048Key(); + auto x509_cert = GenerateX509Certificate(rsa_2048->GetEvpPkey()); + std::string pkey_pem = Key::ToPEMString(rsa_2048->GetEvpPkey()); + std::string cert_pem = X509ToPEMString(x509_cert.get()); + + return pairing_server_new(pswd, pswd_len, peer_info, + reinterpret_cast<const uint8_t*>(cert_pem.data()), cert_pem.size(), + reinterpret_cast<const uint8_t*>(pkey_pem.data()), pkey_pem.size(), + port); +} + +void pairing_server_destroy(PairingServerCtx* ctx) { + CHECK(ctx); + delete ctx; +} |