summaryrefslogtreecommitdiff
path: root/adb/pairing_connection/pairing_connection.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'adb/pairing_connection/pairing_connection.cpp')
-rw-r--r--adb/pairing_connection/pairing_connection.cpp491
1 files changed, 491 insertions, 0 deletions
diff --git a/adb/pairing_connection/pairing_connection.cpp b/adb/pairing_connection/pairing_connection.cpp
new file mode 100644
index 000000000..a26a6b4d2
--- /dev/null
+++ b/adb/pairing_connection/pairing_connection.cpp
@@ -0,0 +1,491 @@
+/*
+ * Copyright (C) 2020 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "adb/pairing/pairing_connection.h"
+
+#include <stddef.h>
+#include <stdint.h>
+
+#include <functional>
+#include <memory>
+#include <string_view>
+#include <thread>
+#include <vector>
+
+#include <adb/pairing/pairing_auth.h>
+#include <adb/tls/tls_connection.h>
+#include <android-base/endian.h>
+#include <android-base/logging.h>
+#include <android-base/macros.h>
+#include <android-base/unique_fd.h>
+
+#include "pairing.pb.h"
+
+using namespace adb;
+using android::base::unique_fd;
+using TlsError = tls::TlsConnection::TlsError;
+
+const uint8_t kCurrentKeyHeaderVersion = 1;
+const uint8_t kMinSupportedKeyHeaderVersion = 1;
+const uint8_t kMaxSupportedKeyHeaderVersion = 1;
+const uint32_t kMaxPayloadSize = kMaxPeerInfoSize * 2;
+
+struct PairingPacketHeader {
+ uint8_t version; // PairingPacket version
+ uint8_t type; // the type of packet (PairingPacket.Type)
+ uint32_t payload; // Size of the payload in bytes
+} __attribute__((packed));
+
+struct PairingAuthDeleter {
+ void operator()(PairingAuthCtx* p) { pairing_auth_destroy(p); }
+}; // PairingAuthDeleter
+using PairingAuthPtr = std::unique_ptr<PairingAuthCtx, PairingAuthDeleter>;
+
+// PairingConnectionCtx encapsulates the protocol to authenticate two peers with
+// each other. This class will open the tcp sockets and handle the pairing
+// process. On completion, both sides will have each other's public key
+// (certificate) if successful, otherwise, the pairing failed. The tcp port
+// number is hardcoded (see pairing_connection.cpp).
+//
+// Each PairingConnectionCtx instance represents a different device trying to
+// pair. So for the device, we can have multiple PairingConnectionCtxs while the
+// host may have only one (unless host has a PairingServer).
+//
+// See pairing_connection_test.cpp for example usage.
+//
+struct PairingConnectionCtx {
+ public:
+ using Data = std::vector<uint8_t>;
+ using ResultCallback = pairing_result_cb;
+ enum class Role {
+ Client,
+ Server,
+ };
+
+ explicit PairingConnectionCtx(Role role, const Data& pswd, const PeerInfo& peer_info,
+ const Data& certificate, const Data& priv_key);
+ virtual ~PairingConnectionCtx();
+
+ // Starts the pairing connection on a separate thread.
+ // Upon completion, if the pairing was successful,
+ // |cb| will be called with the peer information and certificate.
+ // Otherwise, |cb| will be called with empty data. |fd| should already
+ // be opened. PairingConnectionCtx will take ownership of the |fd|.
+ //
+ // Pairing is successful if both server/client uses the same non-empty
+ // |pswd|, and they are able to exchange the information. |pswd| and
+ // |certificate| must be non-empty. Start() can only be called once in the
+ // lifetime of this object.
+ //
+ // Returns true if the thread was successfully started, false otherwise.
+ bool Start(int fd, ResultCallback cb, void* opaque);
+
+ private:
+ // Setup the tls connection.
+ bool SetupTlsConnection();
+
+ /************ PairingPacketHeader methods ****************/
+ // Tries to write out the header and payload.
+ bool WriteHeader(const PairingPacketHeader* header, std::string_view payload);
+ // Tries to parse incoming data into the |header|. Returns true if header
+ // is valid and header version is supported. |header| is filled on success.
+ // |header| may contain garbage if unsuccessful.
+ bool ReadHeader(PairingPacketHeader* header);
+ // Creates a PairingPacketHeader.
+ void CreateHeader(PairingPacketHeader* header, adb::proto::PairingPacket::Type type,
+ uint32_t payload_size);
+ // Checks if actual matches expected.
+ bool CheckHeaderType(adb::proto::PairingPacket::Type expected, uint8_t actual);
+
+ /*********** State related methods **************/
+ // Handles the State::ExchangingMsgs state.
+ bool DoExchangeMsgs();
+ // Handles the State::ExchangingPeerInfo state.
+ bool DoExchangePeerInfo();
+
+ // The background task to do the pairing.
+ void StartWorker();
+
+ // Calls |cb_| and sets the state to Stopped.
+ void NotifyResult(const PeerInfo* p);
+
+ static PairingAuthPtr CreatePairingAuthPtr(Role role, const Data& pswd);
+
+ enum class State {
+ Ready,
+ ExchangingMsgs,
+ ExchangingPeerInfo,
+ Stopped,
+ };
+
+ std::atomic<State> state_{State::Ready};
+ Role role_;
+ Data pswd_;
+ PeerInfo peer_info_;
+ Data cert_;
+ Data priv_key_;
+
+ // Peer's info
+ PeerInfo their_info_;
+
+ ResultCallback cb_;
+ void* opaque_ = nullptr;
+ std::unique_ptr<tls::TlsConnection> tls_;
+ PairingAuthPtr auth_;
+ unique_fd fd_;
+ std::thread thread_;
+ static constexpr size_t kExportedKeySize = 64;
+}; // PairingConnectionCtx
+
+PairingConnectionCtx::PairingConnectionCtx(Role role, const Data& pswd, const PeerInfo& peer_info,
+ const Data& cert, const Data& priv_key)
+ : role_(role), pswd_(pswd), peer_info_(peer_info), cert_(cert), priv_key_(priv_key) {
+ CHECK(!pswd_.empty() && !cert_.empty() && !priv_key_.empty());
+}
+
+PairingConnectionCtx::~PairingConnectionCtx() {
+ // Force close the fd and wait for the worker thread to finish.
+ fd_.reset();
+ if (thread_.joinable()) {
+ thread_.join();
+ }
+}
+
+bool PairingConnectionCtx::SetupTlsConnection() {
+ tls_ = tls::TlsConnection::Create(
+ role_ == Role::Server ? tls::TlsConnection::Role::Server
+ : tls::TlsConnection::Role::Client,
+ std::string_view(reinterpret_cast<const char*>(cert_.data()), cert_.size()),
+ std::string_view(reinterpret_cast<const char*>(priv_key_.data()), priv_key_.size()),
+ fd_);
+
+ if (tls_ == nullptr) {
+ LOG(ERROR) << "Unable to start TlsConnection. Unable to pair fd=" << fd_.get();
+ return false;
+ }
+
+ // Allow any peer certificate
+ tls_->SetCertVerifyCallback([](X509_STORE_CTX*) { return 1; });
+
+ // SSL doesn't seem to behave correctly with fdevents so just do a blocking
+ // read for the pairing data.
+ if (tls_->DoHandshake() != TlsError::Success) {
+ LOG(ERROR) << "Failed to handshake with the peer fd=" << fd_.get();
+ return false;
+ }
+
+ // To ensure the connection is not stolen while we do the PAKE, append the
+ // exported key material from the tls connection to the password.
+ std::vector<uint8_t> exportedKeyMaterial = tls_->ExportKeyingMaterial(kExportedKeySize);
+ if (exportedKeyMaterial.empty()) {
+ LOG(ERROR) << "Failed to export key material";
+ return false;
+ }
+ pswd_.insert(pswd_.end(), std::make_move_iterator(exportedKeyMaterial.begin()),
+ std::make_move_iterator(exportedKeyMaterial.end()));
+ auth_ = CreatePairingAuthPtr(role_, pswd_);
+
+ return true;
+}
+
+bool PairingConnectionCtx::WriteHeader(const PairingPacketHeader* header,
+ std::string_view payload) {
+ PairingPacketHeader network_header = *header;
+ network_header.payload = htonl(network_header.payload);
+ if (!tls_->WriteFully(std::string_view(reinterpret_cast<const char*>(&network_header),
+ sizeof(PairingPacketHeader))) ||
+ !tls_->WriteFully(payload)) {
+ LOG(ERROR) << "Failed to write out PairingPacketHeader";
+ state_ = State::Stopped;
+ return false;
+ }
+ return true;
+}
+
+bool PairingConnectionCtx::ReadHeader(PairingPacketHeader* header) {
+ auto data = tls_->ReadFully(sizeof(PairingPacketHeader));
+ if (data.empty()) {
+ return false;
+ }
+
+ uint8_t* p = data.data();
+ // First byte is always PairingPacketHeader version
+ header->version = *p;
+ ++p;
+ if (header->version < kMinSupportedKeyHeaderVersion ||
+ header->version > kMaxSupportedKeyHeaderVersion) {
+ LOG(ERROR) << "PairingPacketHeader version mismatch (us=" << kCurrentKeyHeaderVersion
+ << " them=" << header->version << ")";
+ return false;
+ }
+ // Next byte is the PairingPacket::Type
+ if (!adb::proto::PairingPacket::Type_IsValid(*p)) {
+ LOG(ERROR) << "Unknown PairingPacket type=" << static_cast<uint32_t>(*p);
+ return false;
+ }
+ header->type = *p;
+ ++p;
+ // Last, the payload size
+ header->payload = ntohl(*(reinterpret_cast<uint32_t*>(p)));
+ if (header->payload == 0 || header->payload > kMaxPayloadSize) {
+ LOG(ERROR) << "header payload not within a safe payload size (size=" << header->payload
+ << ")";
+ return false;
+ }
+
+ return true;
+}
+
+void PairingConnectionCtx::CreateHeader(PairingPacketHeader* header,
+ adb::proto::PairingPacket::Type type,
+ uint32_t payload_size) {
+ header->version = kCurrentKeyHeaderVersion;
+ uint8_t type8 = static_cast<uint8_t>(static_cast<int>(type));
+ header->type = type8;
+ header->payload = payload_size;
+}
+
+bool PairingConnectionCtx::CheckHeaderType(adb::proto::PairingPacket::Type expected_type,
+ uint8_t actual) {
+ uint8_t expected = *reinterpret_cast<uint8_t*>(&expected_type);
+ if (actual != expected) {
+ LOG(ERROR) << "Unexpected header type (expected=" << static_cast<uint32_t>(expected)
+ << " actual=" << static_cast<uint32_t>(actual) << ")";
+ return false;
+ }
+ return true;
+}
+
+void PairingConnectionCtx::NotifyResult(const PeerInfo* p) {
+ cb_(p, fd_.get(), opaque_);
+ state_ = State::Stopped;
+}
+
+bool PairingConnectionCtx::Start(int fd, ResultCallback cb, void* opaque) {
+ if (fd < 0) {
+ return false;
+ }
+
+ State expected = State::Ready;
+ if (!state_.compare_exchange_strong(expected, State::ExchangingMsgs)) {
+ return false;
+ }
+
+ fd_.reset(fd);
+ cb_ = cb;
+ opaque_ = opaque;
+
+ thread_ = std::thread([this] { StartWorker(); });
+ return true;
+}
+
+bool PairingConnectionCtx::DoExchangeMsgs() {
+ uint32_t payload = pairing_auth_msg_size(auth_.get());
+ std::vector<uint8_t> msg(payload);
+ pairing_auth_get_spake2_msg(auth_.get(), msg.data());
+
+ PairingPacketHeader header;
+ CreateHeader(&header, adb::proto::PairingPacket::SPAKE2_MSG, payload);
+
+ // Write our SPAKE2 msg
+ if (!WriteHeader(&header,
+ std::string_view(reinterpret_cast<const char*>(msg.data()), msg.size()))) {
+ LOG(ERROR) << "Failed to write SPAKE2 msg.";
+ return false;
+ }
+
+ // Read the peer's SPAKE2 msg header
+ if (!ReadHeader(&header)) {
+ LOG(ERROR) << "Invalid PairingPacketHeader.";
+ return false;
+ }
+ if (!CheckHeaderType(adb::proto::PairingPacket::SPAKE2_MSG, header.type)) {
+ return false;
+ }
+
+ // Read the SPAKE2 msg payload and initialize the cipher for
+ // encrypting the PeerInfo and certificate.
+ auto their_msg = tls_->ReadFully(header.payload);
+ if (their_msg.empty() ||
+ !pairing_auth_init_cipher(auth_.get(), their_msg.data(), their_msg.size())) {
+ LOG(ERROR) << "Unable to initialize pairing cipher [their_msg.size=" << their_msg.size()
+ << "]";
+ return false;
+ }
+
+ return true;
+}
+
+bool PairingConnectionCtx::DoExchangePeerInfo() {
+ // Encrypt PeerInfo
+ std::vector<uint8_t> buf;
+ uint8_t* p = reinterpret_cast<uint8_t*>(&peer_info_);
+ buf.assign(p, p + sizeof(peer_info_));
+ std::vector<uint8_t> outbuf(pairing_auth_safe_encrypted_size(auth_.get(), buf.size()));
+ CHECK(!outbuf.empty());
+ size_t outsize;
+ if (!pairing_auth_encrypt(auth_.get(), buf.data(), buf.size(), outbuf.data(), &outsize)) {
+ LOG(ERROR) << "Failed to encrypt peer info";
+ return false;
+ }
+ outbuf.resize(outsize);
+
+ // Write out the packet header
+ PairingPacketHeader out_header;
+ out_header.version = kCurrentKeyHeaderVersion;
+ out_header.type = static_cast<uint8_t>(static_cast<int>(adb::proto::PairingPacket::PEER_INFO));
+ out_header.payload = htonl(outbuf.size());
+ if (!tls_->WriteFully(
+ std::string_view(reinterpret_cast<const char*>(&out_header), sizeof(out_header)))) {
+ LOG(ERROR) << "Unable to write PairingPacketHeader";
+ return false;
+ }
+
+ // Write out the encrypted payload
+ if (!tls_->WriteFully(
+ std::string_view(reinterpret_cast<const char*>(outbuf.data()), outbuf.size()))) {
+ LOG(ERROR) << "Unable to write encrypted peer info";
+ return false;
+ }
+
+ // Read in the peer's packet header
+ PairingPacketHeader header;
+ if (!ReadHeader(&header)) {
+ LOG(ERROR) << "Invalid PairingPacketHeader.";
+ return false;
+ }
+
+ if (!CheckHeaderType(adb::proto::PairingPacket::PEER_INFO, header.type)) {
+ return false;
+ }
+
+ // Read in the encrypted peer certificate
+ buf = tls_->ReadFully(header.payload);
+ if (buf.empty()) {
+ return false;
+ }
+
+ // Try to decrypt the certificate
+ outbuf.resize(pairing_auth_safe_decrypted_size(auth_.get(), buf.data(), buf.size()));
+ if (outbuf.empty()) {
+ LOG(ERROR) << "Unsupported payload while decrypting peer info.";
+ return false;
+ }
+
+ if (!pairing_auth_decrypt(auth_.get(), buf.data(), buf.size(), outbuf.data(), &outsize)) {
+ LOG(ERROR) << "Failed to decrypt";
+ return false;
+ }
+ outbuf.resize(outsize);
+
+ // The decrypted message should contain the PeerInfo.
+ if (outbuf.size() != sizeof(PeerInfo)) {
+ LOG(ERROR) << "Got size=" << outbuf.size() << "PeerInfo.size=" << sizeof(PeerInfo);
+ return false;
+ }
+
+ p = outbuf.data();
+ ::memcpy(&their_info_, p, sizeof(PeerInfo));
+ p += sizeof(PeerInfo);
+
+ return true;
+}
+
+void PairingConnectionCtx::StartWorker() {
+ // Setup the secure transport
+ if (!SetupTlsConnection()) {
+ NotifyResult(nullptr);
+ return;
+ }
+
+ for (;;) {
+ switch (state_) {
+ case State::ExchangingMsgs:
+ if (!DoExchangeMsgs()) {
+ NotifyResult(nullptr);
+ return;
+ }
+ state_ = State::ExchangingPeerInfo;
+ break;
+ case State::ExchangingPeerInfo:
+ if (!DoExchangePeerInfo()) {
+ NotifyResult(nullptr);
+ return;
+ }
+ NotifyResult(&their_info_);
+ return;
+ case State::Ready:
+ case State::Stopped:
+ LOG(FATAL) << __func__ << ": Got invalid state";
+ return;
+ }
+ }
+}
+
+// static
+PairingAuthPtr PairingConnectionCtx::CreatePairingAuthPtr(Role role, const Data& pswd) {
+ switch (role) {
+ case Role::Client:
+ return PairingAuthPtr(pairing_auth_client_new(pswd.data(), pswd.size()));
+ break;
+ case Role::Server:
+ return PairingAuthPtr(pairing_auth_server_new(pswd.data(), pswd.size()));
+ break;
+ }
+}
+
+static PairingConnectionCtx* CreateConnection(PairingConnectionCtx::Role role, const uint8_t* pswd,
+ size_t pswd_len, const PeerInfo* peer_info,
+ const uint8_t* x509_cert_pem, size_t x509_size,
+ const uint8_t* priv_key_pem, size_t priv_size) {
+ CHECK(pswd);
+ CHECK_GT(pswd_len, 0U);
+ CHECK(x509_cert_pem);
+ CHECK_GT(x509_size, 0U);
+ CHECK(priv_key_pem);
+ CHECK_GT(priv_size, 0U);
+ CHECK(peer_info);
+ std::vector<uint8_t> vec_pswd(pswd, pswd + pswd_len);
+ std::vector<uint8_t> vec_x509_cert(x509_cert_pem, x509_cert_pem + x509_size);
+ std::vector<uint8_t> vec_priv_key(priv_key_pem, priv_key_pem + priv_size);
+ return new PairingConnectionCtx(role, vec_pswd, *peer_info, vec_x509_cert, vec_priv_key);
+}
+
+PairingConnectionCtx* pairing_connection_client_new(const uint8_t* pswd, size_t pswd_len,
+ const PeerInfo* peer_info,
+ const uint8_t* x509_cert_pem, size_t x509_size,
+ const uint8_t* priv_key_pem, size_t priv_size) {
+ return CreateConnection(PairingConnectionCtx::Role::Client, pswd, pswd_len, peer_info,
+ x509_cert_pem, x509_size, priv_key_pem, priv_size);
+}
+
+PairingConnectionCtx* pairing_connection_server_new(const uint8_t* pswd, size_t pswd_len,
+ const PeerInfo* peer_info,
+ const uint8_t* x509_cert_pem, size_t x509_size,
+ const uint8_t* priv_key_pem, size_t priv_size) {
+ return CreateConnection(PairingConnectionCtx::Role::Server, pswd, pswd_len, peer_info,
+ x509_cert_pem, x509_size, priv_key_pem, priv_size);
+}
+
+void pairing_connection_destroy(PairingConnectionCtx* ctx) {
+ CHECK(ctx);
+ delete ctx;
+}
+
+bool pairing_connection_start(PairingConnectionCtx* ctx, int fd, pairing_result_cb cb,
+ void* opaque) {
+ return ctx->Start(fd, cb, opaque);
+}