diff options
author | Haamed Gheibi <haamed@google.com> | 2022-02-09 14:35:06 -0800 |
---|---|---|
committer | Haamed Gheibi <haamed@google.com> | 2022-02-09 14:41:16 -0800 |
commit | ab52181d73b04e131fd72e32d69b5123a5d6892b (patch) | |
tree | 0ac86b537180b6fb97716b3058dfae44af9eaac7 /neuralnetworks/utils/adapter/aidl/src/Burst.cpp | |
parent | f99b35c293439db0b7436b47b939eb8c7bf21b51 (diff) | |
parent | 4d2548cfa7b86b79a516be9b60f6b666cc9af682 (diff) |
Merge TP1A.220126.001
Change-Id: Ibf6bd2c20d9927fde8b2a05dde2b58bd8faea20f
Diffstat (limited to 'neuralnetworks/utils/adapter/aidl/src/Burst.cpp')
-rw-r--r-- | neuralnetworks/utils/adapter/aidl/src/Burst.cpp | 28 |
1 files changed, 24 insertions, 4 deletions
diff --git a/neuralnetworks/utils/adapter/aidl/src/Burst.cpp b/neuralnetworks/utils/adapter/aidl/src/Burst.cpp index 4fabb20635..a4a80faf2a 100644 --- a/neuralnetworks/utils/adapter/aidl/src/Burst.cpp +++ b/neuralnetworks/utils/adapter/aidl/src/Burst.cpp @@ -93,7 +93,8 @@ std::vector<nn::IBurst::OptionalCacheHold> ensureAllMemoriesAreCached( nn::ExecutionResult<ExecutionResult> executeSynchronously( const nn::IBurst& burst, const Burst::ThreadSafeMemoryCache& cache, const Request& request, const std::vector<int64_t>& memoryIdentifierTokens, bool measureTiming, int64_t deadlineNs, - int64_t loopTimeoutDurationNs) { + int64_t loopTimeoutDurationNs, const std::vector<TokenValuePair>& hints, + const std::vector<ExtensionNameAndPrefix>& extensionNameToPrefix) { if (request.pools.size() != memoryIdentifierTokens.size()) { return NN_ERROR(nn::ErrorStatus::INVALID_ARGUMENT) << "request.pools.size() != memoryIdentifierTokens.size()"; @@ -107,11 +108,13 @@ nn::ExecutionResult<ExecutionResult> executeSynchronously( const auto nnMeasureTiming = measureTiming ? nn::MeasureTiming::YES : nn::MeasureTiming::NO; const auto nnDeadline = NN_TRY(makeOptionalTimePoint(deadlineNs)); const auto nnLoopTimeoutDuration = NN_TRY(makeOptionalDuration(loopTimeoutDurationNs)); + auto nnHints = NN_TRY(convertInput(hints)); + auto nnExtensionNameToPrefix = NN_TRY(convertInput(extensionNameToPrefix)); const auto hold = ensureAllMemoriesAreCached(&nnRequest, memoryIdentifierTokens, burst, cache); - const auto result = - burst.execute(nnRequest, nnMeasureTiming, nnDeadline, nnLoopTimeoutDuration); + const auto result = burst.execute(nnRequest, nnMeasureTiming, nnDeadline, nnLoopTimeoutDuration, + nnHints, nnExtensionNameToPrefix); if (!result.ok() && result.error().code == nn::ErrorStatus::OUTPUT_INSUFFICIENT_SIZE) { const auto& [message, code, outputShapes] = result.error(); @@ -155,7 +158,24 @@ ndk::ScopedAStatus Burst::executeSynchronously(const Request& request, ExecutionResult* executionResult) { auto result = adapter::executeSynchronously(*kBurst, kMemoryCache, request, memoryIdentifierTokens, - measureTiming, deadlineNs, loopTimeoutDurationNs); + measureTiming, deadlineNs, loopTimeoutDurationNs, {}, {}); + if (!result.has_value()) { + auto [message, code, _] = std::move(result).error(); + const auto aidlCode = utils::convert(code).value_or(ErrorStatus::GENERAL_FAILURE); + return ndk::ScopedAStatus::fromServiceSpecificErrorWithMessage( + static_cast<int32_t>(aidlCode), message.c_str()); + } + *executionResult = std::move(result).value(); + return ndk::ScopedAStatus::ok(); +} + +ndk::ScopedAStatus Burst::executeSynchronouslyWithConfig( + const Request& request, const std::vector<int64_t>& memoryIdentifierTokens, + const ExecutionConfig& config, int64_t deadlineNs, ExecutionResult* executionResult) { + auto result = adapter::executeSynchronously( + *kBurst, kMemoryCache, request, memoryIdentifierTokens, config.measureTiming, + deadlineNs, config.loopTimeoutDurationNs, config.executionHints, + config.extensionNameToPrefix); if (!result.has_value()) { auto [message, code, _] = std::move(result).error(); const auto aidlCode = utils::convert(code).value_or(ErrorStatus::GENERAL_FAILURE); |