diff options
Diffstat (limited to 'neuralnetworks/utils/adapter/aidl/src/Burst.cpp')
-rw-r--r-- | neuralnetworks/utils/adapter/aidl/src/Burst.cpp | 28 |
1 files changed, 24 insertions, 4 deletions
diff --git a/neuralnetworks/utils/adapter/aidl/src/Burst.cpp b/neuralnetworks/utils/adapter/aidl/src/Burst.cpp index 4fabb20635..a4a80faf2a 100644 --- a/neuralnetworks/utils/adapter/aidl/src/Burst.cpp +++ b/neuralnetworks/utils/adapter/aidl/src/Burst.cpp @@ -93,7 +93,8 @@ std::vector<nn::IBurst::OptionalCacheHold> ensureAllMemoriesAreCached( nn::ExecutionResult<ExecutionResult> executeSynchronously( const nn::IBurst& burst, const Burst::ThreadSafeMemoryCache& cache, const Request& request, const std::vector<int64_t>& memoryIdentifierTokens, bool measureTiming, int64_t deadlineNs, - int64_t loopTimeoutDurationNs) { + int64_t loopTimeoutDurationNs, const std::vector<TokenValuePair>& hints, + const std::vector<ExtensionNameAndPrefix>& extensionNameToPrefix) { if (request.pools.size() != memoryIdentifierTokens.size()) { return NN_ERROR(nn::ErrorStatus::INVALID_ARGUMENT) << "request.pools.size() != memoryIdentifierTokens.size()"; @@ -107,11 +108,13 @@ nn::ExecutionResult<ExecutionResult> executeSynchronously( const auto nnMeasureTiming = measureTiming ? nn::MeasureTiming::YES : nn::MeasureTiming::NO; const auto nnDeadline = NN_TRY(makeOptionalTimePoint(deadlineNs)); const auto nnLoopTimeoutDuration = NN_TRY(makeOptionalDuration(loopTimeoutDurationNs)); + auto nnHints = NN_TRY(convertInput(hints)); + auto nnExtensionNameToPrefix = NN_TRY(convertInput(extensionNameToPrefix)); const auto hold = ensureAllMemoriesAreCached(&nnRequest, memoryIdentifierTokens, burst, cache); - const auto result = - burst.execute(nnRequest, nnMeasureTiming, nnDeadline, nnLoopTimeoutDuration); + const auto result = burst.execute(nnRequest, nnMeasureTiming, nnDeadline, nnLoopTimeoutDuration, + nnHints, nnExtensionNameToPrefix); if (!result.ok() && result.error().code == nn::ErrorStatus::OUTPUT_INSUFFICIENT_SIZE) { const auto& [message, code, outputShapes] = result.error(); @@ -155,7 +158,24 @@ ndk::ScopedAStatus Burst::executeSynchronously(const Request& request, ExecutionResult* executionResult) { auto result = adapter::executeSynchronously(*kBurst, kMemoryCache, request, memoryIdentifierTokens, - measureTiming, deadlineNs, loopTimeoutDurationNs); + measureTiming, deadlineNs, loopTimeoutDurationNs, {}, {}); + if (!result.has_value()) { + auto [message, code, _] = std::move(result).error(); + const auto aidlCode = utils::convert(code).value_or(ErrorStatus::GENERAL_FAILURE); + return ndk::ScopedAStatus::fromServiceSpecificErrorWithMessage( + static_cast<int32_t>(aidlCode), message.c_str()); + } + *executionResult = std::move(result).value(); + return ndk::ScopedAStatus::ok(); +} + +ndk::ScopedAStatus Burst::executeSynchronouslyWithConfig( + const Request& request, const std::vector<int64_t>& memoryIdentifierTokens, + const ExecutionConfig& config, int64_t deadlineNs, ExecutionResult* executionResult) { + auto result = adapter::executeSynchronously( + *kBurst, kMemoryCache, request, memoryIdentifierTokens, config.measureTiming, + deadlineNs, config.loopTimeoutDurationNs, config.executionHints, + config.extensionNameToPrefix); if (!result.has_value()) { auto [message, code, _] = std::move(result).error(); const auto aidlCode = utils::convert(code).value_or(ErrorStatus::GENERAL_FAILURE); |