diff options
author | Michael Butler <butlermichael@google.com> | 2021-04-05 23:03:24 +0000 |
---|---|---|
committer | Automerger Merge Worker <android-build-automerger-merge-worker@system.gserviceaccount.com> | 2021-04-05 23:03:24 +0000 |
commit | 564f4060167b5886d2f712d1ec2321846de651fd (patch) | |
tree | b55dc57d4d2fdf75080b3026444a22c8d7f9a49d /neuralnetworks/aidl/utils/src/Conversions.cpp | |
parent | a98f551aad1c49678c6b52bdbeda0fe1c7626e41 (diff) | |
parent | 0f10ee9eb3becffb50e2c037ff0c141cd620f051 (diff) |
Add missing validation for NN canonical types am: 388bcebc8f am: 2f59ea0610 am: 0f10ee9eb3
Original change: https://android-review.googlesource.com/c/platform/hardware/interfaces/+/1663919
Change-Id: I98a53bd66ff52ece26e06e1457592b726a3f9468
Diffstat (limited to 'neuralnetworks/aidl/utils/src/Conversions.cpp')
-rw-r--r-- | neuralnetworks/aidl/utils/src/Conversions.cpp | 93 |
1 files changed, 30 insertions, 63 deletions
diff --git a/neuralnetworks/aidl/utils/src/Conversions.cpp b/neuralnetworks/aidl/utils/src/Conversions.cpp index 45bc005e9f..c74c509a8d 100644 --- a/neuralnetworks/aidl/utils/src/Conversions.cpp +++ b/neuralnetworks/aidl/utils/src/Conversions.cpp @@ -41,6 +41,8 @@ #include <type_traits> #include <utility> +#include "Utils.h" + #define VERIFY_NON_NEGATIVE(value) \ while (UNLIKELY(value < 0)) return NN_ERROR() @@ -53,7 +55,6 @@ constexpr std::underlying_type_t<Type> underlyingType(Type value) { return static_cast<std::underlying_type_t<Type>>(value); } -constexpr auto kVersion = android::nn::Version::ANDROID_S; constexpr int64_t kNoTiming = -1; } // namespace @@ -63,32 +64,6 @@ namespace { using ::aidl::android::hardware::common::NativeHandle; -constexpr auto validOperandType(nn::OperandType operandType) { - switch (operandType) { - case nn::OperandType::FLOAT32: - case nn::OperandType::INT32: - case nn::OperandType::UINT32: - case nn::OperandType::TENSOR_FLOAT32: - case nn::OperandType::TENSOR_INT32: - case nn::OperandType::TENSOR_QUANT8_ASYMM: - case nn::OperandType::BOOL: - case nn::OperandType::TENSOR_QUANT16_SYMM: - case nn::OperandType::TENSOR_FLOAT16: - case nn::OperandType::TENSOR_BOOL8: - case nn::OperandType::FLOAT16: - case nn::OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL: - case nn::OperandType::TENSOR_QUANT16_ASYMM: - case nn::OperandType::TENSOR_QUANT8_SYMM: - case nn::OperandType::TENSOR_QUANT8_ASYMM_SIGNED: - case nn::OperandType::SUBGRAPH: - return true; - case nn::OperandType::OEM: - case nn::OperandType::TENSOR_OEM_BYTE: - return false; - } - return nn::isExtension(operandType); -} - template <typename Input> using UnvalidatedConvertOutput = std::decay_t<decltype(unvalidatedConvert(std::declval<Input>()).value())>; @@ -113,14 +88,7 @@ GeneralResult<std::vector<UnvalidatedConvertOutput<Type>>> unvalidatedConvert( template <typename Type> GeneralResult<UnvalidatedConvertOutput<Type>> validatedConvert(const Type& halObject) { auto canonical = NN_TRY(nn::unvalidatedConvert(halObject)); - const auto maybeVersion = validate(canonical); - if (!maybeVersion.has_value()) { - return error() << maybeVersion.error(); - } - const auto version = maybeVersion.value(); - if (version > kVersion) { - return NN_ERROR() << "Insufficient version: " << version << " vs required " << kVersion; - } + NN_TRY(aidl_hal::utils::compliantVersion(canonical)); return canonical; } @@ -185,13 +153,21 @@ static GeneralResult<UniqueNativeHandle> nativeHandleFromAidlHandle(const Native GeneralResult<OperandType> unvalidatedConvert(const aidl_hal::OperandType& operandType) { VERIFY_NON_NEGATIVE(underlyingType(operandType)) << "Negative operand types are not allowed."; - return static_cast<OperandType>(operandType); + const auto canonical = static_cast<OperandType>(operandType); + if (canonical == OperandType::OEM || canonical == OperandType::TENSOR_OEM_BYTE) { + return NN_ERROR() << "Unable to convert invalid OperandType " << canonical; + } + return canonical; } GeneralResult<OperationType> unvalidatedConvert(const aidl_hal::OperationType& operationType) { VERIFY_NON_NEGATIVE(underlyingType(operationType)) << "Negative operation types are not allowed."; - return static_cast<OperationType>(operationType); + const auto canonical = static_cast<OperationType>(operationType); + if (canonical == OperationType::OEM_OPERATION) { + return NN_ERROR() << "Unable to convert invalid OperationType OEM_OPERATION"; + } + return canonical; } GeneralResult<DeviceType> unvalidatedConvert(const aidl_hal::DeviceType& deviceType) { @@ -206,8 +182,7 @@ GeneralResult<Capabilities> unvalidatedConvert(const aidl_hal::Capabilities& cap const bool validOperandTypes = std::all_of( capabilities.operandPerformance.begin(), capabilities.operandPerformance.end(), [](const aidl_hal::OperandPerformance& operandPerformance) { - const auto maybeType = unvalidatedConvert(operandPerformance.type); - return !maybeType.has_value() ? false : validOperandType(maybeType.value()); + return validatedConvert(operandPerformance.type).has_value(); }); if (!validOperandTypes) { return NN_ERROR() << "Invalid OperandType when unvalidatedConverting OperandPerformance in " @@ -534,6 +509,11 @@ GeneralResult<SharedHandle> unvalidatedConvert(const NativeHandle& aidlNativeHan return std::make_shared<const Handle>(NN_TRY(unvalidatedConvertHelper(aidlNativeHandle))); } +GeneralResult<std::vector<Operation>> unvalidatedConvert( + const std::vector<aidl_hal::Operation>& operations) { + return unvalidatedConvertVec(operations); +} + GeneralResult<SyncFence> unvalidatedConvert(const ndk::ScopedFileDescriptor& syncFence) { auto duplicatedFd = NN_TRY(dupFd(syncFence.get())); return SyncFence::create(std::move(duplicatedFd)); @@ -564,22 +544,14 @@ GeneralResult<Model> convert(const aidl_hal::Model& model) { return validatedConvert(model); } -GeneralResult<Operand> convert(const aidl_hal::Operand& operand) { - return unvalidatedConvert(operand); -} - GeneralResult<OperandType> convert(const aidl_hal::OperandType& operandType) { - return unvalidatedConvert(operandType); + return validatedConvert(operandType); } GeneralResult<Priority> convert(const aidl_hal::Priority& priority) { return validatedConvert(priority); } -GeneralResult<Request::MemoryPool> convert(const aidl_hal::RequestMemoryPool& memoryPool) { - return unvalidatedConvert(memoryPool); -} - GeneralResult<Request> convert(const aidl_hal::Request& request) { return validatedConvert(request); } @@ -589,17 +561,13 @@ GeneralResult<Timing> convert(const aidl_hal::Timing& timing) { } GeneralResult<SyncFence> convert(const ndk::ScopedFileDescriptor& syncFence) { - return unvalidatedConvert(syncFence); + return validatedConvert(syncFence); } GeneralResult<std::vector<Extension>> convert(const std::vector<aidl_hal::Extension>& extension) { return validatedConvert(extension); } -GeneralResult<std::vector<Operation>> convert(const std::vector<aidl_hal::Operation>& operations) { - return unvalidatedConvert(operations); -} - GeneralResult<std::vector<SharedMemory>> convert(const std::vector<aidl_hal::Memory>& memories) { return validatedConvert(memories); } @@ -644,14 +612,7 @@ nn::GeneralResult<std::vector<UnvalidatedConvertOutput<Type>>> unvalidatedConver template <typename Type> nn::GeneralResult<UnvalidatedConvertOutput<Type>> validatedConvert(const Type& canonical) { - const auto maybeVersion = nn::validate(canonical); - if (!maybeVersion.has_value()) { - return nn::error() << maybeVersion.error(); - } - const auto version = maybeVersion.value(); - if (version > kVersion) { - return NN_ERROR() << "Insufficient version: " << version << " vs required " << kVersion; - } + NN_TRY(compliantVersion(canonical)); return utils::unvalidatedConvert(canonical); } @@ -797,6 +758,9 @@ nn::GeneralResult<ExecutionPreference> unvalidatedConvert( } nn::GeneralResult<OperandType> unvalidatedConvert(const nn::OperandType& operandType) { + if (operandType == nn::OperandType::OEM || operandType == nn::OperandType::TENSOR_OEM_BYTE) { + return NN_ERROR() << "Unable to convert invalid OperandType " << operandType; + } return static_cast<OperandType>(operandType); } @@ -864,6 +828,9 @@ nn::GeneralResult<Operand> unvalidatedConvert(const nn::Operand& operand) { } nn::GeneralResult<OperationType> unvalidatedConvert(const nn::OperationType& operationType) { + if (operationType == nn::OperationType::OEM_OPERATION) { + return NN_ERROR() << "Unable to convert invalid OperationType OEM_OPERATION"; + } return static_cast<OperationType>(operationType); } @@ -1004,7 +971,7 @@ nn::GeneralResult<ndk::ScopedFileDescriptor> unvalidatedConvertCache( } nn::GeneralResult<std::vector<uint8_t>> convert(const nn::CacheToken& cacheToken) { - return unvalidatedConvert(cacheToken); + return validatedConvert(cacheToken); } nn::GeneralResult<BufferDesc> convert(const nn::BufferDesc& bufferDesc) { @@ -1076,7 +1043,7 @@ nn::GeneralResult<std::vector<ndk::ScopedFileDescriptor>> convert( nn::GeneralResult<std::vector<ndk::ScopedFileDescriptor>> convert( const std::vector<nn::SyncFence>& syncFences) { - return unvalidatedConvert(syncFences); + return validatedConvert(syncFences); } nn::GeneralResult<std::vector<int32_t>> toSigned(const std::vector<uint32_t>& vec) { |