summaryrefslogtreecommitdiff
path: root/neuralnetworks/utils/adapter/aidl/src/Burst.cpp
diff options
context:
space:
mode:
authorHaamed Gheibi <haamed@google.com>2022-02-09 14:35:06 -0800
committerHaamed Gheibi <haamed@google.com>2022-02-09 14:41:16 -0800
commitab52181d73b04e131fd72e32d69b5123a5d6892b (patch)
tree0ac86b537180b6fb97716b3058dfae44af9eaac7 /neuralnetworks/utils/adapter/aidl/src/Burst.cpp
parentf99b35c293439db0b7436b47b939eb8c7bf21b51 (diff)
parent4d2548cfa7b86b79a516be9b60f6b666cc9af682 (diff)
Merge TP1A.220126.001
Change-Id: Ibf6bd2c20d9927fde8b2a05dde2b58bd8faea20f
Diffstat (limited to 'neuralnetworks/utils/adapter/aidl/src/Burst.cpp')
-rw-r--r--neuralnetworks/utils/adapter/aidl/src/Burst.cpp28
1 files changed, 24 insertions, 4 deletions
diff --git a/neuralnetworks/utils/adapter/aidl/src/Burst.cpp b/neuralnetworks/utils/adapter/aidl/src/Burst.cpp
index 4fabb20635..a4a80faf2a 100644
--- a/neuralnetworks/utils/adapter/aidl/src/Burst.cpp
+++ b/neuralnetworks/utils/adapter/aidl/src/Burst.cpp
@@ -93,7 +93,8 @@ std::vector<nn::IBurst::OptionalCacheHold> ensureAllMemoriesAreCached(
nn::ExecutionResult<ExecutionResult> executeSynchronously(
const nn::IBurst& burst, const Burst::ThreadSafeMemoryCache& cache, const Request& request,
const std::vector<int64_t>& memoryIdentifierTokens, bool measureTiming, int64_t deadlineNs,
- int64_t loopTimeoutDurationNs) {
+ int64_t loopTimeoutDurationNs, const std::vector<TokenValuePair>& hints,
+ const std::vector<ExtensionNameAndPrefix>& extensionNameToPrefix) {
if (request.pools.size() != memoryIdentifierTokens.size()) {
return NN_ERROR(nn::ErrorStatus::INVALID_ARGUMENT)
<< "request.pools.size() != memoryIdentifierTokens.size()";
@@ -107,11 +108,13 @@ nn::ExecutionResult<ExecutionResult> executeSynchronously(
const auto nnMeasureTiming = measureTiming ? nn::MeasureTiming::YES : nn::MeasureTiming::NO;
const auto nnDeadline = NN_TRY(makeOptionalTimePoint(deadlineNs));
const auto nnLoopTimeoutDuration = NN_TRY(makeOptionalDuration(loopTimeoutDurationNs));
+ auto nnHints = NN_TRY(convertInput(hints));
+ auto nnExtensionNameToPrefix = NN_TRY(convertInput(extensionNameToPrefix));
const auto hold = ensureAllMemoriesAreCached(&nnRequest, memoryIdentifierTokens, burst, cache);
- const auto result =
- burst.execute(nnRequest, nnMeasureTiming, nnDeadline, nnLoopTimeoutDuration);
+ const auto result = burst.execute(nnRequest, nnMeasureTiming, nnDeadline, nnLoopTimeoutDuration,
+ nnHints, nnExtensionNameToPrefix);
if (!result.ok() && result.error().code == nn::ErrorStatus::OUTPUT_INSUFFICIENT_SIZE) {
const auto& [message, code, outputShapes] = result.error();
@@ -155,7 +158,24 @@ ndk::ScopedAStatus Burst::executeSynchronously(const Request& request,
ExecutionResult* executionResult) {
auto result =
adapter::executeSynchronously(*kBurst, kMemoryCache, request, memoryIdentifierTokens,
- measureTiming, deadlineNs, loopTimeoutDurationNs);
+ measureTiming, deadlineNs, loopTimeoutDurationNs, {}, {});
+ if (!result.has_value()) {
+ auto [message, code, _] = std::move(result).error();
+ const auto aidlCode = utils::convert(code).value_or(ErrorStatus::GENERAL_FAILURE);
+ return ndk::ScopedAStatus::fromServiceSpecificErrorWithMessage(
+ static_cast<int32_t>(aidlCode), message.c_str());
+ }
+ *executionResult = std::move(result).value();
+ return ndk::ScopedAStatus::ok();
+}
+
+ndk::ScopedAStatus Burst::executeSynchronouslyWithConfig(
+ const Request& request, const std::vector<int64_t>& memoryIdentifierTokens,
+ const ExecutionConfig& config, int64_t deadlineNs, ExecutionResult* executionResult) {
+ auto result = adapter::executeSynchronously(
+ *kBurst, kMemoryCache, request, memoryIdentifierTokens, config.measureTiming,
+ deadlineNs, config.loopTimeoutDurationNs, config.executionHints,
+ config.extensionNameToPrefix);
if (!result.has_value()) {
auto [message, code, _] = std::move(result).error();
const auto aidlCode = utils::convert(code).value_or(ErrorStatus::GENERAL_FAILURE);