diff options
Diffstat (limited to 'neuralnetworks/1.2/utils/src/Conversions.cpp')
-rw-r--r-- | neuralnetworks/1.2/utils/src/Conversions.cpp | 92 |
1 files changed, 21 insertions, 71 deletions
diff --git a/neuralnetworks/1.2/utils/src/Conversions.cpp b/neuralnetworks/1.2/utils/src/Conversions.cpp index 2c45583d0c..29945b75e5 100644 --- a/neuralnetworks/1.2/utils/src/Conversions.cpp +++ b/neuralnetworks/1.2/utils/src/Conversions.cpp @@ -37,6 +37,8 @@ #include <type_traits> #include <utility> +#include "Utils.h" + namespace { template <typename Type> @@ -45,50 +47,23 @@ constexpr std::underlying_type_t<Type> underlyingType(Type value) { } using HalDuration = std::chrono::duration<uint64_t, std::micro>; -constexpr auto kVersion = android::nn::Version::ANDROID_Q; -constexpr uint64_t kNoTiming = std::numeric_limits<uint64_t>::max(); } // namespace namespace android::nn { namespace { -constexpr bool validOperandType(OperandType operandType) { - switch (operandType) { - case OperandType::FLOAT32: - case OperandType::INT32: - case OperandType::UINT32: - case OperandType::TENSOR_FLOAT32: - case OperandType::TENSOR_INT32: - case OperandType::TENSOR_QUANT8_ASYMM: - case OperandType::BOOL: - case OperandType::TENSOR_QUANT16_SYMM: - case OperandType::TENSOR_FLOAT16: - case OperandType::TENSOR_BOOL8: - case OperandType::FLOAT16: - case OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL: - case OperandType::TENSOR_QUANT16_ASYMM: - case OperandType::TENSOR_QUANT8_SYMM: - case OperandType::OEM: - case OperandType::TENSOR_OEM_BYTE: - return true; - default: - break; - } - return isExtension(operandType); -} - using hardware::hidl_handle; using hardware::hidl_vec; template <typename Input> -using unvalidatedConvertOutput = +using UnvalidatedConvertOutput = std::decay_t<decltype(unvalidatedConvert(std::declval<Input>()).value())>; template <typename Type> -GeneralResult<std::vector<unvalidatedConvertOutput<Type>>> unvalidatedConvertVec( +GeneralResult<std::vector<UnvalidatedConvertOutput<Type>>> unvalidatedConvert( const hidl_vec<Type>& arguments) { - std::vector<unvalidatedConvertOutput<Type>> canonical; + std::vector<UnvalidatedConvertOutput<Type>> canonical; canonical.reserve(arguments.size()); for (const auto& argument : arguments) { canonical.push_back(NN_TRY(nn::unvalidatedConvert(argument))); @@ -97,29 +72,16 @@ GeneralResult<std::vector<unvalidatedConvertOutput<Type>>> unvalidatedConvertVec } template <typename Type> -GeneralResult<std::vector<unvalidatedConvertOutput<Type>>> unvalidatedConvert( - const hidl_vec<Type>& arguments) { - return unvalidatedConvertVec(arguments); -} - -template <typename Type> -decltype(nn::unvalidatedConvert(std::declval<Type>())) validatedConvert(const Type& halObject) { +GeneralResult<UnvalidatedConvertOutput<Type>> validatedConvert(const Type& halObject) { auto canonical = NN_TRY(nn::unvalidatedConvert(halObject)); - const auto maybeVersion = validate(canonical); - if (!maybeVersion.has_value()) { - return error() << maybeVersion.error(); - } - const auto version = maybeVersion.value(); - if (version > kVersion) { - return NN_ERROR() << "Insufficient version: " << version << " vs required " << kVersion; - } + NN_TRY(hal::V1_2::utils::compliantVersion(canonical)); return canonical; } template <typename Type> -GeneralResult<std::vector<unvalidatedConvertOutput<Type>>> validatedConvert( +GeneralResult<std::vector<UnvalidatedConvertOutput<Type>>> validatedConvert( const hidl_vec<Type>& arguments) { - std::vector<unvalidatedConvertOutput<Type>> canonical; + std::vector<UnvalidatedConvertOutput<Type>> canonical; canonical.reserve(arguments.size()); for (const auto& argument : arguments) { canonical.push_back(NN_TRY(validatedConvert(argument))); @@ -145,8 +107,7 @@ GeneralResult<Capabilities> unvalidatedConvert(const hal::V1_2::Capabilities& ca const bool validOperandTypes = std::all_of( capabilities.operandPerformance.begin(), capabilities.operandPerformance.end(), [](const hal::V1_2::Capabilities::OperandPerformance& operandPerformance) { - const auto maybeType = unvalidatedConvert(operandPerformance.type); - return !maybeType.has_value() ? false : validOperandType(maybeType.value()); + return validatedConvert(operandPerformance.type).has_value(); }); if (!validOperandTypes) { return NN_ERROR(nn::ErrorStatus::GENERAL_FAILURE) @@ -275,6 +236,7 @@ GeneralResult<MeasureTiming> unvalidatedConvert(const hal::V1_2::MeasureTiming& GeneralResult<Timing> unvalidatedConvert(const hal::V1_2::Timing& timing) { constexpr uint64_t kMaxTiming = std::chrono::floor<HalDuration>(Duration::max()).count(); constexpr auto convertTiming = [](uint64_t halTiming) -> OptionalDuration { + constexpr uint64_t kNoTiming = std::numeric_limits<uint64_t>::max(); if (halTiming == kNoTiming) { return {}; } @@ -378,25 +340,19 @@ nn::GeneralResult<hidl_memory> unvalidatedConvert(const nn::SharedMemory& memory } template <typename Input> -using unvalidatedConvertOutput = +using UnvalidatedConvertOutput = std::decay_t<decltype(unvalidatedConvert(std::declval<Input>()).value())>; template <typename Type> -nn::GeneralResult<hidl_vec<unvalidatedConvertOutput<Type>>> unvalidatedConvertVec( +nn::GeneralResult<hidl_vec<UnvalidatedConvertOutput<Type>>> unvalidatedConvert( const std::vector<Type>& arguments) { - hidl_vec<unvalidatedConvertOutput<Type>> halObject(arguments.size()); + hidl_vec<UnvalidatedConvertOutput<Type>> halObject(arguments.size()); for (size_t i = 0; i < arguments.size(); ++i) { halObject[i] = NN_TRY(unvalidatedConvert(arguments[i])); } return halObject; } -template <typename Type> -nn::GeneralResult<hidl_vec<unvalidatedConvertOutput<Type>>> unvalidatedConvert( - const std::vector<Type>& arguments) { - return unvalidatedConvertVec(arguments); -} - nn::GeneralResult<Operand::ExtraParams> makeExtraParams(nn::Operand::NoParams /*noParams*/) { return Operand::ExtraParams{}; } @@ -416,22 +372,15 @@ nn::GeneralResult<Operand::ExtraParams> makeExtraParams( } template <typename Type> -decltype(utils::unvalidatedConvert(std::declval<Type>())) validatedConvert(const Type& canonical) { - const auto maybeVersion = nn::validate(canonical); - if (!maybeVersion.has_value()) { - return nn::error() << maybeVersion.error(); - } - const auto version = maybeVersion.value(); - if (version > kVersion) { - return NN_ERROR() << "Insufficient version: " << version << " vs required " << kVersion; - } - return utils::unvalidatedConvert(canonical); +nn::GeneralResult<UnvalidatedConvertOutput<Type>> validatedConvert(const Type& canonical) { + NN_TRY(compliantVersion(canonical)); + return unvalidatedConvert(canonical); } template <typename Type> -nn::GeneralResult<hidl_vec<unvalidatedConvertOutput<Type>>> validatedConvert( +nn::GeneralResult<hidl_vec<UnvalidatedConvertOutput<Type>>> validatedConvert( const std::vector<Type>& arguments) { - hidl_vec<unvalidatedConvertOutput<Type>> halObject(arguments.size()); + hidl_vec<UnvalidatedConvertOutput<Type>> halObject(arguments.size()); for (size_t i = 0; i < arguments.size(); ++i) { halObject[i] = NN_TRY(validatedConvert(arguments[i])); } @@ -469,7 +418,7 @@ nn::GeneralResult<Capabilities> unvalidatedConvert(const nn::Capabilities& capab capabilities.operandPerformance.asVector().end(), std::back_inserter(operandPerformance), [](const nn::Capabilities::OperandPerformance& operandPerformance) { - return nn::validOperandType(operandPerformance.type); + return compliantVersion(operandPerformance.type).has_value(); }); return Capabilities{ @@ -570,6 +519,7 @@ nn::GeneralResult<MeasureTiming> unvalidatedConvert(const nn::MeasureTiming& mea nn::GeneralResult<Timing> unvalidatedConvert(const nn::Timing& timing) { constexpr auto convertTiming = [](nn::OptionalDuration canonicalTiming) -> uint64_t { + constexpr uint64_t kNoTiming = std::numeric_limits<uint64_t>::max(); if (!canonicalTiming.has_value()) { return kNoTiming; } |