summaryrefslogtreecommitdiff
path: root/neuralnetworks/utils/common/include/nnapi/hal/TransferValue.h
blob: 6679afefecd5c39e0a29e632157de9c90647d10c (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
/*
 * Copyright (C) 2020 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef ANDROID_HARDWARE_INTERFACES_NEURALNETWORKS_UTILS_TRANSFER_VALUE_H
#define ANDROID_HARDWARE_INTERFACES_NEURALNETWORKS_UTILS_TRANSFER_VALUE_H

#include <android-base/logging.h>
#include <android-base/thread_annotations.h>

#include <condition_variable>
#include <functional>
#include <mutex>
#include <optional>
#include <type_traits>
#include <utility>

namespace android::hardware::neuralnetworks::utils {

// Adapter around a plain function pointer. It provides two things:
// 1) An implicit conversion to a `std::function` callback. Invoking that callback calls the
//    wrapped function and stores what it returns inside this `CallbackValue`.
// 2) `take`, which hands the stored result back to the caller.
//
// The callback produced by the conversion operator captures `this`, so the `CallbackValue` object
// must stay alive, at the same address, for as long as the callback may be invoked.
//
// This class is thread compatible.
template <typename ReturnType, typename... ArgTypes>
class CallbackValue final {
  public:
    // Signature of the wrapped function.
    using FunctionType = ReturnType (*)(ArgTypes...);
    // Signature of the generated callback: same arguments, but the result is stored, not returned.
    using CallbackType = std::function<void(ArgTypes...)>;

    explicit CallbackValue(FunctionType fn);

    // Produces a callback that passes its arguments on to `mFunction` and records the returned
    // value in `mReturnValue`.
    /*implicit*/ operator CallbackType();  // NOLINT(google-explicit-constructor)

    // Moves the result of the `mFunction` call out of this object.
    // Precondition: mReturnValue.has_value()
    // Postcondition: !mReturnValue.has_value()
    [[nodiscard]] ReturnType take();

  private:
    std::optional<ReturnType> mReturnValue;
    FunctionType mFunction;
};

// Deduction guide for CallbackValue when constructed with a function pointer: lets
// `CallbackValue(fn)` deduce `ReturnType` and `ArgTypes...` from the signature of `fn`.
template <typename ReturnType, typename... ArgTypes>
CallbackValue(ReturnType (*)(ArgTypes...))->CallbackValue<ReturnType, ArgTypes...>;

// Thread-safe container to pass a value between threads.
//
// All access to the stored value is serialized by `mMutex`; `take` waits on `mCondition` until
// `put` has stored a value. `put` and `take` are declared `const` and the members are `mutable`,
// so a value can be handed off through a `const TransferValue`.
template <typename Type>
class TransferValue final {
  public:
    // Put the value in `TransferValue` and notify all threads waiting in `take`. If
    // `TransferValue` already has a value, this function is a no-op: the stored value is left
    // unchanged and `object` is discarded.
    void put(Type object) const;

    // Take the value stored in `TransferValue`. If no value is available, this function will block
    // until the value becomes available.
    // Postcondition: !mObject.has_value()
    [[nodiscard]] Type take() const;

  private:
    // Protects `mObject`.
    mutable std::mutex mMutex;
    // Notified by `put` after a value has been stored.
    mutable std::condition_variable mCondition;
    // The value in transit; empty until `put` stores a value and empty again once `take` has
    // removed it.
    mutable std::optional<Type> mObject GUARDED_BY(mMutex);
};

// template implementations

template <typename ReturnType, typename... ArgTypes>
CallbackValue<ReturnType, ArgTypes...>::CallbackValue(FunctionType fn) : mFunction(fn) {}

// Builds the callback for this object. When invoked, the callback calls `mFunction` with the
// arguments it received and stores the result in `mReturnValue`, where `take` can later retrieve
// it.
//
// Arguments are forwarded according to their declared `ArgTypes`: by-value parameters are moved
// into `mFunction` instead of being copied a second time, and reference parameters are passed
// through unchanged. This also makes move-only by-value parameters and rvalue-reference
// parameters usable.
//
// The callback captures `this`, so it must not be invoked after this `CallbackValue` has been
// destroyed.
template <typename ReturnType, typename... ArgTypes>
CallbackValue<ReturnType, ArgTypes...>::operator CallbackType() {
    return [this](ArgTypes... args) {
        mReturnValue = mFunction(std::forward<ArgTypes>(args)...);
    };
}

// Removes and returns the stored result of the `mFunction` call. Aborts via CHECK if no result
// is currently stored.
template <typename ReturnType, typename... ArgTypes>
ReturnType CallbackValue<ReturnType, ArgTypes...>::take() {
    CHECK(mReturnValue.has_value());

    // Exchange the stored result with an empty optional so that `mReturnValue` holds nothing once
    // this function returns.
    std::optional<ReturnType> taken;
    taken.swap(mReturnValue);
    return *std::move(taken);
}

// Stores `object` if the container is empty and then notifies all waiters. When a value is
// already present nothing is stored and nobody is notified.
template <typename Type>
void TransferValue<Type>::put(Type object) const {
    bool stored = false;
    {
        std::lock_guard guard(mMutex);
        // Only the first value is kept; later ones are dropped.
        if (!mObject.has_value()) {
            mObject.emplace(std::move(object));
            stored = true;
        }
    }
    // Notify after the lock has been released.
    if (stored) {
        mCondition.notify_all();
    }
}

// Blocks until a value has been stored, then removes it from the container and returns it.
template <typename Type>
Type TransferValue<Type>::take() const {
    std::unique_lock lock(mMutex);
    // Tells the thread-safety analysis that `mMutex` is held from here on.
    base::ScopedLockAssertion lockAssertion(mMutex);

    const auto valueIsPresent = [this]() REQUIRES(mMutex) { return mObject.has_value(); };
    mCondition.wait(lock, valueIsPresent);
    CHECK(mObject.has_value());

    // Exchange the stored value with an empty optional so that `mObject` holds nothing afterwards.
    std::optional<Type> taken;
    taken.swap(mObject);
    return *std::move(taken);
}

}  // namespace android::hardware::neuralnetworks::utils

#endif  // ANDROID_HARDWARE_INTERFACES_NEURALNETWORKS_UTILS_TRANSFER_VALUE_H