summaryrefslogtreecommitdiff
path: root/adb/pairing_connection/pairing_server.cpp
diff options
context:
space:
mode:
authorJoshua Duong <joshuaduong@google.com>2020-02-07 11:06:16 -0800
committerJoshua Duong <joshuaduong@google.com>2020-02-21 21:06:12 +0000
commitc7a1fb8fd970c6fa41843f48cc7bcc02fd07800d (patch)
tree802bc92dc2935c5a13681a117f17168aed96d3e1 /adb/pairing_connection/pairing_server.cpp
parentdf8f1217d09de155420371dee30ebf5d884d1cd7 (diff)
[adbwifi] Add pairing_connection library.
Bug: 111434128 Bug: 119494503 Test: atest adb_pairing_connection_test Change-Id: I54d68c65067809832266d6c3043b63222c98a9cd Exempt-From-Owner-Approval: approved already
Diffstat (limited to 'adb/pairing_connection/pairing_server.cpp')
-rw-r--r--adb/pairing_connection/pairing_server.cpp466
1 files changed, 466 insertions, 0 deletions
diff --git a/adb/pairing_connection/pairing_server.cpp b/adb/pairing_connection/pairing_server.cpp
new file mode 100644
index 000000000..7218eacf2
--- /dev/null
+++ b/adb/pairing_connection/pairing_server.cpp
@@ -0,0 +1,466 @@
+/*
+ * Copyright (C) 2020 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "adb/pairing/pairing_server.h"
+
+#include <sys/epoll.h>
+#include <sys/eventfd.h>
+
+#include <atomic>
+#include <deque>
+#include <iomanip>
+#include <mutex>
+#include <sstream>
+#include <thread>
+#include <tuple>
+#include <unordered_map>
+#include <variant>
+#include <vector>
+
+#include <adb/crypto/rsa_2048_key.h>
+#include <adb/crypto/x509_generator.h>
+#include <adb/pairing/pairing_connection.h>
+#include <android-base/logging.h>
+#include <android-base/parsenetaddress.h>
+#include <android-base/thread_annotations.h>
+#include <android-base/unique_fd.h>
+#include <cutils/sockets.h>
+
+#include "internal/constants.h"
+
+using android::base::ScopedLockAssertion;
+using android::base::unique_fd;
+using namespace adb::crypto;
+using namespace adb::pairing;
+
+// The implementation has two background threads running: one to handle and
+// accept any new pairing connection requests (socket accept), and the other to
+// handle connection events (connection started, connection finished).
+struct PairingServerCtx {
+ public:
+ using Data = std::vector<uint8_t>;
+
+ virtual ~PairingServerCtx();
+
+ // All parameters must be non-empty.
+ explicit PairingServerCtx(const Data& pswd, const PeerInfo& peer_info, const Data& cert,
+ const Data& priv_key, uint16_t port);
+
+ // Starts the pairing server. This call is non-blocking. Upon completion,
+ // if the pairing was successful, then |cb| will be called with the PublicKeyHeader
+ // containing the info of the trusted peer. Otherwise, |cb| will be
+ // called with an empty value. Start can only be called once in the lifetime
+ // of this object.
+ //
+ // Returns the port number if PairingServerCtx was successfully started. Otherwise,
+ // returns 0.
+ uint16_t Start(pairing_server_result_cb cb, void* opaque);
+
+ private:
+ // Setup the server socket to accept incoming connections. Returns the
+ // server port number (> 0 on success).
+ uint16_t SetupServer();
+ // Force stop the server thread.
+ void StopServer();
+
+ // handles a new pairing client connection
+ bool HandleNewClientConnection(int fd) EXCLUDES(conn_mutex_);
+
+ // ======== connection events thread =============
+ std::mutex conn_mutex_;
+ std::condition_variable conn_cv_;
+
+ using FdVal = int;
+ struct ConnectionDeleter {
+ void operator()(PairingConnectionCtx* p) { pairing_connection_destroy(p); }
+ };
+ using ConnectionPtr = std::unique_ptr<PairingConnectionCtx, ConnectionDeleter>;
+ static ConnectionPtr CreatePairingConnection(const Data& pswd, const PeerInfo& peer_info,
+ const Data& cert, const Data& priv_key);
+ using NewConnectionEvent = std::tuple<unique_fd, ConnectionPtr>;
+ // <fd, PeerInfo.type, PeerInfo.data>
+ using ConnectionFinishedEvent = std::tuple<FdVal, uint8_t, std::optional<std::string>>;
+ using ConnectionEvent = std::variant<NewConnectionEvent, ConnectionFinishedEvent>;
+ // Queue for connections to write into. We have a separate queue to read
+ // from, in order to minimize the time the server thread is blocked.
+ std::deque<ConnectionEvent> conn_write_queue_ GUARDED_BY(conn_mutex_);
+ std::deque<ConnectionEvent> conn_read_queue_;
+ // Map of fds to their PairingConnections currently running.
+ std::unordered_map<FdVal, ConnectionPtr> connections_;
+
+ // Two threads launched when starting the pairing server:
+ // 1) A server thread that waits for incoming client connections, and
+ // 2) A connection events thread that synchonizes events from all of the
+ // clients, since each PairingConnection is running in it's own thread.
+ void StartConnectionEventsThread();
+ void StartServerThread();
+
+ static void PairingConnectionCallback(const PeerInfo* peer_info, int fd, void* opaque);
+
+ std::thread conn_events_thread_;
+ void ConnectionEventsWorker();
+ std::thread server_thread_;
+ void ServerWorker();
+ bool is_terminate_ GUARDED_BY(conn_mutex_) = false;
+
+ enum class State {
+ Ready,
+ Running,
+ Stopped,
+ };
+ State state_ = State::Ready;
+ Data pswd_;
+ PeerInfo peer_info_;
+ Data cert_;
+ Data priv_key_;
+ uint16_t port_;
+
+ pairing_server_result_cb cb_;
+ void* opaque_ = nullptr;
+ bool got_valid_pairing_ = false;
+
+ static const int kEpollConstSocket = 0;
+ // Used to break the server thread from epoll_wait
+ static const int kEpollConstEventFd = 1;
+ unique_fd epoll_fd_;
+ unique_fd server_fd_;
+ unique_fd event_fd_;
+}; // PairingServerCtx
+
+// static
+PairingServerCtx::ConnectionPtr PairingServerCtx::CreatePairingConnection(const Data& pswd,
+ const PeerInfo& peer_info,
+ const Data& cert,
+ const Data& priv_key) {
+ return ConnectionPtr(pairing_connection_server_new(pswd.data(), pswd.size(), &peer_info,
+ cert.data(), cert.size(), priv_key.data(),
+ priv_key.size()));
+}
+
+PairingServerCtx::PairingServerCtx(const Data& pswd, const PeerInfo& peer_info, const Data& cert,
+ const Data& priv_key, uint16_t port)
+ : pswd_(pswd), peer_info_(peer_info), cert_(cert), priv_key_(priv_key), port_(port) {
+ CHECK(!pswd_.empty() && !cert_.empty() && !priv_key_.empty());
+}
+
+PairingServerCtx::~PairingServerCtx() {
+ // Since these connections have references to us, let's make sure they
+ // destruct before us.
+ if (server_thread_.joinable()) {
+ StopServer();
+ server_thread_.join();
+ }
+
+ {
+ std::lock_guard<std::mutex> lock(conn_mutex_);
+ is_terminate_ = true;
+ }
+ conn_cv_.notify_one();
+ if (conn_events_thread_.joinable()) {
+ conn_events_thread_.join();
+ }
+
+ // Notify the cb_ if it hasn't already.
+ if (!got_valid_pairing_ && cb_ != nullptr) {
+ cb_(nullptr, opaque_);
+ }
+}
+
+uint16_t PairingServerCtx::Start(pairing_server_result_cb cb, void* opaque) {
+ cb_ = cb;
+ opaque_ = opaque;
+
+ if (state_ != State::Ready) {
+ LOG(ERROR) << "PairingServerCtx already running or stopped";
+ return 0;
+ }
+
+ port_ = SetupServer();
+ if (port_ == 0) {
+ LOG(ERROR) << "Unable to start PairingServer";
+ state_ = State::Stopped;
+ return 0;
+ }
+ LOG(INFO) << "Pairing server started on port " << port_;
+
+ state_ = State::Running;
+ return port_;
+}
+
+void PairingServerCtx::StopServer() {
+ if (event_fd_.get() == -1) {
+ return;
+ }
+ uint64_t value = 1;
+ ssize_t rc = write(event_fd_.get(), &value, sizeof(value));
+ if (rc == -1) {
+ // This can happen if the server didn't start.
+ PLOG(ERROR) << "write to eventfd failed";
+ } else if (rc != sizeof(value)) {
+ LOG(FATAL) << "write to event returned short (" << rc << ")";
+ }
+}
+
+uint16_t PairingServerCtx::SetupServer() {
+ epoll_fd_.reset(epoll_create1(EPOLL_CLOEXEC));
+ if (epoll_fd_ == -1) {
+ PLOG(ERROR) << "failed to create epoll fd";
+ return 0;
+ }
+
+ event_fd_.reset(eventfd(0, EFD_CLOEXEC | EFD_NONBLOCK));
+ if (event_fd_ == -1) {
+ PLOG(ERROR) << "failed to create eventfd";
+ return 0;
+ }
+
+ server_fd_.reset(socket_inaddr_any_server(port_, SOCK_STREAM));
+ if (server_fd_.get() == -1) {
+ PLOG(ERROR) << "Failed to start pairing connection server";
+ return 0;
+ } else if (fcntl(server_fd_.get(), F_SETFD, FD_CLOEXEC) != 0) {
+ PLOG(ERROR) << "Failed to make server socket cloexec";
+ return 0;
+ } else if (fcntl(server_fd_.get(), F_SETFD, O_NONBLOCK) != 0) {
+ PLOG(ERROR) << "Failed to make server socket nonblocking";
+ return 0;
+ }
+
+ StartConnectionEventsThread();
+ StartServerThread();
+ int port = socket_get_local_port(server_fd_.get());
+ return (port <= 0 ? 0 : port);
+}
+
+void PairingServerCtx::StartServerThread() {
+ server_thread_ = std::thread([this]() { ServerWorker(); });
+}
+
+void PairingServerCtx::StartConnectionEventsThread() {
+ conn_events_thread_ = std::thread([this]() { ConnectionEventsWorker(); });
+}
+
+void PairingServerCtx::ServerWorker() {
+ {
+ struct epoll_event event;
+ event.events = EPOLLIN;
+ event.data.u64 = kEpollConstSocket;
+ CHECK_EQ(0, epoll_ctl(epoll_fd_.get(), EPOLL_CTL_ADD, server_fd_.get(), &event));
+ }
+
+ {
+ struct epoll_event event;
+ event.events = EPOLLIN;
+ event.data.u64 = kEpollConstEventFd;
+ CHECK_EQ(0, epoll_ctl(epoll_fd_.get(), EPOLL_CTL_ADD, event_fd_.get(), &event));
+ }
+
+ while (true) {
+ struct epoll_event events[2];
+ int rc = TEMP_FAILURE_RETRY(epoll_wait(epoll_fd_.get(), events, 2, -1));
+ if (rc == -1) {
+ PLOG(ERROR) << "epoll_wait failed";
+ return;
+ } else if (rc == 0) {
+ LOG(ERROR) << "epoll_wait returned 0";
+ return;
+ }
+
+ for (int i = 0; i < rc; ++i) {
+ struct epoll_event& event = events[i];
+ switch (event.data.u64) {
+ case kEpollConstSocket:
+ HandleNewClientConnection(server_fd_.get());
+ break;
+ case kEpollConstEventFd:
+ uint64_t dummy;
+ int rc = TEMP_FAILURE_RETRY(read(event_fd_.get(), &dummy, sizeof(dummy)));
+ if (rc != sizeof(dummy)) {
+ PLOG(FATAL) << "failed to read from eventfd (rc=" << rc << ")";
+ }
+ return;
+ }
+ }
+ }
+}
+
+// static
+void PairingServerCtx::PairingConnectionCallback(const PeerInfo* peer_info, int fd, void* opaque) {
+ auto* p = reinterpret_cast<PairingServerCtx*>(opaque);
+
+ ConnectionFinishedEvent event;
+ if (peer_info != nullptr) {
+ if (peer_info->type == ADB_RSA_PUB_KEY) {
+ event = std::make_tuple(fd, peer_info->type,
+ std::string(reinterpret_cast<const char*>(peer_info->data)));
+ } else {
+ LOG(WARNING) << "Ignoring successful pairing because of unknown "
+ << "PeerInfo type=" << peer_info->type;
+ }
+ } else {
+ event = std::make_tuple(fd, 0, std::nullopt);
+ }
+ {
+ std::lock_guard<std::mutex> lock(p->conn_mutex_);
+ p->conn_write_queue_.push_back(std::move(event));
+ }
+ p->conn_cv_.notify_one();
+}
+
+void PairingServerCtx::ConnectionEventsWorker() {
+ uint8_t num_tries = 0;
+ for (;;) {
+ // Transfer the write queue to the read queue.
+ {
+ std::unique_lock<std::mutex> lock(conn_mutex_);
+ ScopedLockAssertion assume_locked(conn_mutex_);
+
+ if (is_terminate_) {
+ // We check |is_terminate_| twice because condition_variable's
+ // notify() only wakes up a thread if it is in the wait state
+ // prior to notify(). Furthermore, we aren't holding the mutex
+ // when processing the events in |conn_read_queue_|.
+ return;
+ }
+ if (conn_write_queue_.empty()) {
+ // We need to wait for new events, or the termination signal.
+ conn_cv_.wait(lock, [this]() REQUIRES(conn_mutex_) {
+ return (is_terminate_ || !conn_write_queue_.empty());
+ });
+ }
+ if (is_terminate_) {
+ // We're done.
+ return;
+ }
+ // Move all events into the read queue.
+ conn_read_queue_ = std::move(conn_write_queue_);
+ conn_write_queue_.clear();
+ }
+
+ // Process all events in the read queue.
+ while (conn_read_queue_.size() > 0) {
+ auto& event = conn_read_queue_.front();
+ if (auto* p = std::get_if<NewConnectionEvent>(&event)) {
+ // Ignore if we are already at the max number of connections
+ if (connections_.size() >= internal::kMaxConnections) {
+ conn_read_queue_.pop_front();
+ continue;
+ }
+ auto [ufd, connection] = std::move(*p);
+ int fd = ufd.release();
+ bool started = pairing_connection_start(connection.get(), fd,
+ PairingConnectionCallback, this);
+ if (!started) {
+ LOG(ERROR) << "PairingServer unable to start a PairingConnection fd=" << fd;
+ ufd.reset(fd);
+ } else {
+ connections_[fd] = std::move(connection);
+ }
+ } else if (auto* p = std::get_if<ConnectionFinishedEvent>(&event)) {
+ auto [fd, info_type, public_key] = std::move(*p);
+ if (public_key.has_value() && !public_key->empty()) {
+ // Valid pairing. Let's shutdown the server and close any
+ // pairing connections in progress.
+ StopServer();
+ connections_.clear();
+
+ PeerInfo info = {};
+ info.type = info_type;
+ strncpy(reinterpret_cast<char*>(info.data), public_key->data(),
+ public_key->size());
+
+ cb_(&info, opaque_);
+
+ got_valid_pairing_ = true;
+ return;
+ }
+ // Invalid pairing. Close the invalid connection.
+ if (connections_.find(fd) != connections_.end()) {
+ connections_.erase(fd);
+ }
+
+ if (++num_tries >= internal::kMaxPairingAttempts) {
+ cb_(nullptr, opaque_);
+ // To prevent the destructor from calling it again.
+ cb_ = nullptr;
+ return;
+ }
+ }
+ conn_read_queue_.pop_front();
+ }
+ }
+}
+
+bool PairingServerCtx::HandleNewClientConnection(int fd) {
+ unique_fd ufd(TEMP_FAILURE_RETRY(accept4(fd, nullptr, nullptr, SOCK_CLOEXEC)));
+ if (ufd == -1) {
+ PLOG(WARNING) << "adb_socket_accept failed fd=" << fd;
+ return false;
+ }
+ auto connection = CreatePairingConnection(pswd_, peer_info_, cert_, priv_key_);
+ if (connection == nullptr) {
+ LOG(ERROR) << "PairingServer unable to create a PairingConnection fd=" << fd;
+ return false;
+ }
+ // send the new connection to the connection thread for further processing
+ NewConnectionEvent event = std::make_tuple(std::move(ufd), std::move(connection));
+ {
+ std::lock_guard<std::mutex> lock(conn_mutex_);
+ conn_write_queue_.push_back(std::move(event));
+ }
+ conn_cv_.notify_one();
+
+ return true;
+}
+
+uint16_t pairing_server_start(PairingServerCtx* ctx, pairing_server_result_cb cb, void* opaque) {
+ return ctx->Start(cb, opaque);
+}
+
+PairingServerCtx* pairing_server_new(const uint8_t* pswd, size_t pswd_len,
+ const PeerInfo* peer_info, const uint8_t* x509_cert_pem,
+ size_t x509_size, const uint8_t* priv_key_pem,
+ size_t priv_size, uint16_t port) {
+ CHECK(pswd);
+ CHECK_GT(pswd_len, 0U);
+ CHECK(x509_cert_pem);
+ CHECK_GT(x509_size, 0U);
+ CHECK(priv_key_pem);
+ CHECK_GT(priv_size, 0U);
+ CHECK(peer_info);
+ std::vector<uint8_t> vec_pswd(pswd, pswd + pswd_len);
+ std::vector<uint8_t> vec_x509_cert(x509_cert_pem, x509_cert_pem + x509_size);
+ std::vector<uint8_t> vec_priv_key(priv_key_pem, priv_key_pem + priv_size);
+ return new PairingServerCtx(vec_pswd, *peer_info, vec_x509_cert, vec_priv_key, port);
+}
+
+PairingServerCtx* pairing_server_new_no_cert(const uint8_t* pswd, size_t pswd_len,
+ const PeerInfo* peer_info, uint16_t port) {
+ auto rsa_2048 = CreateRSA2048Key();
+ auto x509_cert = GenerateX509Certificate(rsa_2048->GetEvpPkey());
+ std::string pkey_pem = Key::ToPEMString(rsa_2048->GetEvpPkey());
+ std::string cert_pem = X509ToPEMString(x509_cert.get());
+
+ return pairing_server_new(pswd, pswd_len, peer_info,
+ reinterpret_cast<const uint8_t*>(cert_pem.data()), cert_pem.size(),
+ reinterpret_cast<const uint8_t*>(pkey_pem.data()), pkey_pem.size(),
+ port);
+}
+
+void pairing_server_destroy(PairingServerCtx* ctx) {
+ CHECK(ctx);
+ delete ctx;
+}