diff options
Diffstat (limited to 'neuralnetworks/aidl/utils/src/Conversions.cpp')
-rw-r--r-- | neuralnetworks/aidl/utils/src/Conversions.cpp | 102 |
1 files changed, 35 insertions, 67 deletions
diff --git a/neuralnetworks/aidl/utils/src/Conversions.cpp b/neuralnetworks/aidl/utils/src/Conversions.cpp index 45bc005e9f..d5f7f81663 100644 --- a/neuralnetworks/aidl/utils/src/Conversions.cpp +++ b/neuralnetworks/aidl/utils/src/Conversions.cpp @@ -41,6 +41,8 @@ #include <type_traits> #include <utility> +#include "Utils.h" + #define VERIFY_NON_NEGATIVE(value) \ while (UNLIKELY(value < 0)) return NN_ERROR() @@ -53,7 +55,6 @@ constexpr std::underlying_type_t<Type> underlyingType(Type value) { return static_cast<std::underlying_type_t<Type>>(value); } -constexpr auto kVersion = android::nn::Version::ANDROID_S; constexpr int64_t kNoTiming = -1; } // namespace @@ -63,32 +64,6 @@ namespace { using ::aidl::android::hardware::common::NativeHandle; -constexpr auto validOperandType(nn::OperandType operandType) { - switch (operandType) { - case nn::OperandType::FLOAT32: - case nn::OperandType::INT32: - case nn::OperandType::UINT32: - case nn::OperandType::TENSOR_FLOAT32: - case nn::OperandType::TENSOR_INT32: - case nn::OperandType::TENSOR_QUANT8_ASYMM: - case nn::OperandType::BOOL: - case nn::OperandType::TENSOR_QUANT16_SYMM: - case nn::OperandType::TENSOR_FLOAT16: - case nn::OperandType::TENSOR_BOOL8: - case nn::OperandType::FLOAT16: - case nn::OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL: - case nn::OperandType::TENSOR_QUANT16_ASYMM: - case nn::OperandType::TENSOR_QUANT8_SYMM: - case nn::OperandType::TENSOR_QUANT8_ASYMM_SIGNED: - case nn::OperandType::SUBGRAPH: - return true; - case nn::OperandType::OEM: - case nn::OperandType::TENSOR_OEM_BYTE: - return false; - } - return nn::isExtension(operandType); -} - template <typename Input> using UnvalidatedConvertOutput = std::decay_t<decltype(unvalidatedConvert(std::declval<Input>()).value())>; @@ -113,14 +88,7 @@ GeneralResult<std::vector<UnvalidatedConvertOutput<Type>>> unvalidatedConvert( template <typename Type> GeneralResult<UnvalidatedConvertOutput<Type>> validatedConvert(const Type& halObject) { auto canonical = NN_TRY(nn::unvalidatedConvert(halObject)); - const auto maybeVersion = validate(canonical); - if (!maybeVersion.has_value()) { - return error() << maybeVersion.error(); - } - const auto version = maybeVersion.value(); - if (version > kVersion) { - return NN_ERROR() << "Insufficient version: " << version << " vs required " << kVersion; - } + NN_TRY(aidl_hal::utils::compliantVersion(canonical)); return canonical; } @@ -185,13 +153,21 @@ static GeneralResult<UniqueNativeHandle> nativeHandleFromAidlHandle(const Native GeneralResult<OperandType> unvalidatedConvert(const aidl_hal::OperandType& operandType) { VERIFY_NON_NEGATIVE(underlyingType(operandType)) << "Negative operand types are not allowed."; - return static_cast<OperandType>(operandType); + const auto canonical = static_cast<OperandType>(operandType); + if (canonical == OperandType::OEM || canonical == OperandType::TENSOR_OEM_BYTE) { + return NN_ERROR() << "Unable to convert invalid OperandType " << canonical; + } + return canonical; } GeneralResult<OperationType> unvalidatedConvert(const aidl_hal::OperationType& operationType) { VERIFY_NON_NEGATIVE(underlyingType(operationType)) << "Negative operation types are not allowed."; - return static_cast<OperationType>(operationType); + const auto canonical = static_cast<OperationType>(operationType); + if (canonical == OperationType::OEM_OPERATION) { + return NN_ERROR() << "Unable to convert invalid OperationType OEM_OPERATION"; + } + return canonical; } GeneralResult<DeviceType> unvalidatedConvert(const aidl_hal::DeviceType& deviceType) { @@ -206,8 +182,7 @@ GeneralResult<Capabilities> unvalidatedConvert(const aidl_hal::Capabilities& cap const bool validOperandTypes = std::all_of( capabilities.operandPerformance.begin(), capabilities.operandPerformance.end(), [](const aidl_hal::OperandPerformance& operandPerformance) { - const auto maybeType = unvalidatedConvert(operandPerformance.type); - return !maybeType.has_value() ? false : validOperandType(maybeType.value()); + return validatedConvert(operandPerformance.type).has_value(); }); if (!validOperandTypes) { return NN_ERROR() << "Invalid OperandType when unvalidatedConverting OperandPerformance in " @@ -534,6 +509,11 @@ GeneralResult<SharedHandle> unvalidatedConvert(const NativeHandle& aidlNativeHan return std::make_shared<const Handle>(NN_TRY(unvalidatedConvertHelper(aidlNativeHandle))); } +GeneralResult<std::vector<Operation>> unvalidatedConvert( + const std::vector<aidl_hal::Operation>& operations) { + return unvalidatedConvertVec(operations); +} + GeneralResult<SyncFence> unvalidatedConvert(const ndk::ScopedFileDescriptor& syncFence) { auto duplicatedFd = NN_TRY(dupFd(syncFence.get())); return SyncFence::create(std::move(duplicatedFd)); @@ -564,22 +544,14 @@ GeneralResult<Model> convert(const aidl_hal::Model& model) { return validatedConvert(model); } -GeneralResult<Operand> convert(const aidl_hal::Operand& operand) { - return unvalidatedConvert(operand); -} - GeneralResult<OperandType> convert(const aidl_hal::OperandType& operandType) { - return unvalidatedConvert(operandType); + return validatedConvert(operandType); } GeneralResult<Priority> convert(const aidl_hal::Priority& priority) { return validatedConvert(priority); } -GeneralResult<Request::MemoryPool> convert(const aidl_hal::RequestMemoryPool& memoryPool) { - return unvalidatedConvert(memoryPool); -} - GeneralResult<Request> convert(const aidl_hal::Request& request) { return validatedConvert(request); } @@ -589,17 +561,13 @@ GeneralResult<Timing> convert(const aidl_hal::Timing& timing) { } GeneralResult<SyncFence> convert(const ndk::ScopedFileDescriptor& syncFence) { - return unvalidatedConvert(syncFence); + return validatedConvert(syncFence); } GeneralResult<std::vector<Extension>> convert(const std::vector<aidl_hal::Extension>& extension) { return validatedConvert(extension); } -GeneralResult<std::vector<Operation>> convert(const std::vector<aidl_hal::Operation>& operations) { - return unvalidatedConvert(operations); -} - GeneralResult<std::vector<SharedMemory>> convert(const std::vector<aidl_hal::Memory>& memories) { return validatedConvert(memories); } @@ -644,14 +612,7 @@ nn::GeneralResult<std::vector<UnvalidatedConvertOutput<Type>>> unvalidatedConver template <typename Type> nn::GeneralResult<UnvalidatedConvertOutput<Type>> validatedConvert(const Type& canonical) { - const auto maybeVersion = nn::validate(canonical); - if (!maybeVersion.has_value()) { - return nn::error() << maybeVersion.error(); - } - const auto version = maybeVersion.value(); - if (version > kVersion) { - return NN_ERROR() << "Insufficient version: " << version << " vs required " << kVersion; - } + NN_TRY(compliantVersion(canonical)); return utils::unvalidatedConvert(canonical); } @@ -797,6 +758,9 @@ nn::GeneralResult<ExecutionPreference> unvalidatedConvert( } nn::GeneralResult<OperandType> unvalidatedConvert(const nn::OperandType& operandType) { + if (operandType == nn::OperandType::OEM || operandType == nn::OperandType::TENSOR_OEM_BYTE) { + return NN_ERROR() << "Unable to convert invalid OperandType " << operandType; + } return static_cast<OperandType>(operandType); } @@ -864,6 +828,9 @@ nn::GeneralResult<Operand> unvalidatedConvert(const nn::Operand& operand) { } nn::GeneralResult<OperationType> unvalidatedConvert(const nn::OperationType& operationType) { + if (operationType == nn::OperationType::OEM_OPERATION) { + return NN_ERROR() << "Unable to convert invalid OperationType OEM_OPERATION"; + } return static_cast<OperationType>(operationType); } @@ -964,11 +931,12 @@ nn::GeneralResult<Timing> unvalidatedConvert(const nn::Timing& timing) { } nn::GeneralResult<int64_t> unvalidatedConvert(const nn::Duration& duration) { - const uint64_t nanoseconds = duration.count(); - if (nanoseconds > std::numeric_limits<int64_t>::max()) { - return std::numeric_limits<int64_t>::max(); + if (duration < nn::Duration::zero()) { + return NN_ERROR() << "Unable to convert invalid (negative) duration"; } - return static_cast<int64_t>(nanoseconds); + constexpr std::chrono::nanoseconds::rep kIntMax = std::numeric_limits<int64_t>::max(); + const auto count = duration.count(); + return static_cast<int64_t>(std::min(count, kIntMax)); } nn::GeneralResult<int64_t> unvalidatedConvert(const nn::OptionalDuration& optionalDuration) { @@ -1004,7 +972,7 @@ nn::GeneralResult<ndk::ScopedFileDescriptor> unvalidatedConvertCache( } nn::GeneralResult<std::vector<uint8_t>> convert(const nn::CacheToken& cacheToken) { - return unvalidatedConvert(cacheToken); + return validatedConvert(cacheToken); } nn::GeneralResult<BufferDesc> convert(const nn::BufferDesc& bufferDesc) { @@ -1076,7 +1044,7 @@ nn::GeneralResult<std::vector<ndk::ScopedFileDescriptor>> convert( nn::GeneralResult<std::vector<ndk::ScopedFileDescriptor>> convert( const std::vector<nn::SyncFence>& syncFences) { - return unvalidatedConvert(syncFences); + return validatedConvert(syncFences); } nn::GeneralResult<std::vector<int32_t>> toSigned(const std::vector<uint32_t>& vec) { |