summaryrefslogtreecommitdiff
path: root/identity/support/src
diff options
context:
space:
mode:
Diffstat (limited to 'identity/support/src')
-rw-r--r--identity/support/src/IdentityCredentialSupport.cpp216
1 files changed, 95 insertions, 121 deletions
diff --git a/identity/support/src/IdentityCredentialSupport.cpp b/identity/support/src/IdentityCredentialSupport.cpp
index 093120d032..38348ac1b0 100644
--- a/identity/support/src/IdentityCredentialSupport.cpp
+++ b/identity/support/src/IdentityCredentialSupport.cpp
@@ -344,15 +344,22 @@ string cborPrettyPrint(const vector<uint8_t>& encodedCbor, size_t maxBStrSize,
// Crypto functionality / abstraction.
// ---------------------------------------------------------------------------
-struct EVP_CIPHER_CTX_Deleter {
- void operator()(EVP_CIPHER_CTX* ctx) const {
- if (ctx != nullptr) {
- EVP_CIPHER_CTX_free(ctx);
- }
- }
-};
-
-using EvpCipherCtxPtr = unique_ptr<EVP_CIPHER_CTX, EVP_CIPHER_CTX_Deleter>;
+using EvpCipherCtxPtr = bssl::UniquePtr<EVP_CIPHER_CTX>;
+using EC_KEY_Ptr = bssl::UniquePtr<EC_KEY>;
+using EVP_PKEY_Ptr = bssl::UniquePtr<EVP_PKEY>;
+using EVP_PKEY_CTX_Ptr = bssl::UniquePtr<EVP_PKEY_CTX>;
+using EC_GROUP_Ptr = bssl::UniquePtr<EC_GROUP>;
+using EC_POINT_Ptr = bssl::UniquePtr<EC_POINT>;
+using ECDSA_SIG_Ptr = bssl::UniquePtr<ECDSA_SIG>;
+using X509_Ptr = bssl::UniquePtr<X509>;
+using PKCS12_Ptr = bssl::UniquePtr<PKCS12>;
+using BIGNUM_Ptr = bssl::UniquePtr<BIGNUM>;
+using ASN1_INTEGER_Ptr = bssl::UniquePtr<ASN1_INTEGER>;
+using ASN1_TIME_Ptr = bssl::UniquePtr<ASN1_TIME>;
+using ASN1_OCTET_STRING_Ptr = bssl::UniquePtr<ASN1_OCTET_STRING>;
+using ASN1_OBJECT_Ptr = bssl::UniquePtr<ASN1_OBJECT>;
+using X509_NAME_Ptr = bssl::UniquePtr<X509_NAME>;
+using X509_EXTENSION_Ptr = bssl::UniquePtr<X509_EXTENSION>;
// bool getRandom(size_t numBytes, vector<uint8_t>& output) {
optional<vector<uint8_t>> getRandom(size_t numBytes) {
@@ -534,115 +541,6 @@ optional<vector<uint8_t>> encryptAes128Gcm(const vector<uint8_t>& key, const vec
return encryptedData;
}
-struct EC_KEY_Deleter {
- void operator()(EC_KEY* key) const {
- if (key != nullptr) {
- EC_KEY_free(key);
- }
- }
-};
-using EC_KEY_Ptr = unique_ptr<EC_KEY, EC_KEY_Deleter>;
-
-struct EVP_PKEY_Deleter {
- void operator()(EVP_PKEY* key) const {
- if (key != nullptr) {
- EVP_PKEY_free(key);
- }
- }
-};
-using EVP_PKEY_Ptr = unique_ptr<EVP_PKEY, EVP_PKEY_Deleter>;
-
-struct EVP_PKEY_CTX_Deleter {
- void operator()(EVP_PKEY_CTX* ctx) const {
- if (ctx != nullptr) {
- EVP_PKEY_CTX_free(ctx);
- }
- }
-};
-using EVP_PKEY_CTX_Ptr = unique_ptr<EVP_PKEY_CTX, EVP_PKEY_CTX_Deleter>;
-
-struct EC_GROUP_Deleter {
- void operator()(EC_GROUP* group) const {
- if (group != nullptr) {
- EC_GROUP_free(group);
- }
- }
-};
-using EC_GROUP_Ptr = unique_ptr<EC_GROUP, EC_GROUP_Deleter>;
-
-struct EC_POINT_Deleter {
- void operator()(EC_POINT* point) const {
- if (point != nullptr) {
- EC_POINT_free(point);
- }
- }
-};
-
-using EC_POINT_Ptr = unique_ptr<EC_POINT, EC_POINT_Deleter>;
-
-struct ECDSA_SIG_Deleter {
- void operator()(ECDSA_SIG* sig) const {
- if (sig != nullptr) {
- ECDSA_SIG_free(sig);
- }
- }
-};
-using ECDSA_SIG_Ptr = unique_ptr<ECDSA_SIG, ECDSA_SIG_Deleter>;
-
-struct X509_Deleter {
- void operator()(X509* x509) const {
- if (x509 != nullptr) {
- X509_free(x509);
- }
- }
-};
-using X509_Ptr = unique_ptr<X509, X509_Deleter>;
-
-struct PKCS12_Deleter {
- void operator()(PKCS12* pkcs12) const {
- if (pkcs12 != nullptr) {
- PKCS12_free(pkcs12);
- }
- }
-};
-using PKCS12_Ptr = unique_ptr<PKCS12, PKCS12_Deleter>;
-
-struct BIGNUM_Deleter {
- void operator()(BIGNUM* bignum) const {
- if (bignum != nullptr) {
- BN_free(bignum);
- }
- }
-};
-using BIGNUM_Ptr = unique_ptr<BIGNUM, BIGNUM_Deleter>;
-
-struct ASN1_INTEGER_Deleter {
- void operator()(ASN1_INTEGER* value) const {
- if (value != nullptr) {
- ASN1_INTEGER_free(value);
- }
- }
-};
-using ASN1_INTEGER_Ptr = unique_ptr<ASN1_INTEGER, ASN1_INTEGER_Deleter>;
-
-struct ASN1_TIME_Deleter {
- void operator()(ASN1_TIME* value) const {
- if (value != nullptr) {
- ASN1_TIME_free(value);
- }
- }
-};
-using ASN1_TIME_Ptr = unique_ptr<ASN1_TIME, ASN1_TIME_Deleter>;
-
-struct X509_NAME_Deleter {
- void operator()(X509_NAME* value) const {
- if (value != nullptr) {
- X509_NAME_free(value);
- }
- }
-};
-using X509_NAME_Ptr = unique_ptr<X509_NAME, X509_NAME_Deleter>;
-
vector<uint8_t> certificateChainJoin(const vector<vector<uint8_t>>& certificateChain) {
vector<uint8_t> ret;
for (const vector<uint8_t>& certificate : certificateChain) {
@@ -1221,8 +1119,19 @@ optional<vector<uint8_t>> ecKeyPairGetPrivateKey(const vector<uint8_t>& keyPair)
return {};
}
vector<uint8_t> privateKey;
- privateKey.resize(BN_num_bytes(bignum));
- BN_bn2bin(bignum, privateKey.data());
+
+ // Note that this may return fewer than 32 bytes so pad with zeroes since we
+ // want to always return 32 bytes.
+ size_t numBytes = BN_num_bytes(bignum);
+ if (numBytes > 32) {
+ LOG(ERROR) << "Size is " << numBytes << ", expected this to be 32 or less";
+ return {};
+ }
+ privateKey.resize(32);
+ for (size_t n = 0; n < 32 - numBytes; n++) {
+ privateKey[n] = 0x00;
+ }
+ BN_bn2bin(bignum, privateKey.data() + 32 - numBytes);
return privateKey;
}
@@ -1379,7 +1288,8 @@ optional<vector<uint8_t>> ecKeyPairGetPkcs12(const vector<uint8_t>& keyPair, con
optional<vector<uint8_t>> ecPublicKeyGenerateCertificate(
const vector<uint8_t>& publicKey, const vector<uint8_t>& signingKey,
const string& serialDecimal, const string& issuer, const string& subject,
- time_t validityNotBefore, time_t validityNotAfter) {
+ time_t validityNotBefore, time_t validityNotAfter,
+ const map<string, vector<uint8_t>>& extensions) {
auto group = EC_GROUP_Ptr(EC_GROUP_new_by_curve_name(NID_X9_62_prime256v1));
auto point = EC_POINT_Ptr(EC_POINT_new(group.get()));
if (EC_POINT_oct2point(group.get(), point.get(), publicKey.data(), publicKey.size(), nullptr) !=
@@ -1482,6 +1392,32 @@ optional<vector<uint8_t>> ecPublicKeyGenerateCertificate(
return {};
}
+ for (auto const& [oidStr, blob] : extensions) {
+ ASN1_OBJECT_Ptr oid(
+ OBJ_txt2obj(oidStr.c_str(), 1)); // accept numerical dotted string form only
+ if (!oid.get()) {
+ LOG(ERROR) << "Error setting OID";
+ return {};
+ }
+ ASN1_OCTET_STRING_Ptr octetString(ASN1_OCTET_STRING_new());
+ if (!ASN1_OCTET_STRING_set(octetString.get(), blob.data(), blob.size())) {
+ LOG(ERROR) << "Error setting octet string for extension";
+ return {};
+ }
+
+ X509_EXTENSION_Ptr extension = X509_EXTENSION_Ptr(X509_EXTENSION_new());
+ extension.reset(X509_EXTENSION_create_by_OBJ(nullptr, oid.get(), 0 /* not critical */,
+ octetString.get()));
+ if (!extension.get()) {
+ LOG(ERROR) << "Error setting extension";
+ return {};
+ }
+ if (!X509_add_ext(x509.get(), extension.get(), -1)) {
+ LOG(ERROR) << "Error adding extension";
+ return {};
+ }
+ }
+
if (X509_sign(x509.get(), privPkey.get(), EVP_sha256()) == 0) {
LOG(ERROR) << "Error signing X509 certificate";
return {};
@@ -1650,6 +1586,44 @@ optional<vector<uint8_t>> certificateChainGetTopMostKey(const vector<uint8_t>& c
return publicKey;
}
+optional<vector<uint8_t>> certificateGetExtension(const vector<uint8_t>& x509Certificate,
+ const string& oidStr) {
+ vector<X509_Ptr> certs;
+ if (!parseX509Certificates(x509Certificate, certs)) {
+ return {};
+ }
+ if (certs.size() < 1) {
+ LOG(ERROR) << "No certificates in chain";
+ return {};
+ }
+
+ ASN1_OBJECT_Ptr oid(
+ OBJ_txt2obj(oidStr.c_str(), 1)); // accept numerical dotted string form only
+ if (!oid.get()) {
+ LOG(ERROR) << "Error setting OID";
+ return {};
+ }
+
+ int location = X509_get_ext_by_OBJ(certs[0].get(), oid.get(), -1 /* search from beginning */);
+ if (location == -1) {
+ return {};
+ }
+
+ X509_EXTENSION* ext = X509_get_ext(certs[0].get(), location);
+ if (ext == nullptr) {
+ return {};
+ }
+
+ ASN1_OCTET_STRING* octetString = X509_EXTENSION_get_data(ext);
+ if (octetString == nullptr) {
+ return {};
+ }
+ vector<uint8_t> result;
+ result.resize(octetString->length);
+ memcpy(result.data(), octetString->data, octetString->length);
+ return result;
+}
+
optional<pair<size_t, size_t>> certificateFindPublicKey(const vector<uint8_t>& x509Certificate) {
vector<X509_Ptr> certs;
if (!parseX509Certificates(x509Certificate, certs)) {