Diffstat (limited to 'neuralnetworks/utils/common/src/CommonUtils.cpp')
-rw-r--r-- | neuralnetworks/utils/common/src/CommonUtils.cpp | 89 |
1 file changed, 40 insertions(+), 49 deletions(-)
diff --git a/neuralnetworks/utils/common/src/CommonUtils.cpp b/neuralnetworks/utils/common/src/CommonUtils.cpp
index 4d26795d89..eaeb9ad872 100644
--- a/neuralnetworks/utils/common/src/CommonUtils.cpp
+++ b/neuralnetworks/utils/common/src/CommonUtils.cpp
@@ -200,10 +200,31 @@ nn::GeneralResult<std::reference_wrapper<const nn::Model>> flushDataFromPointerT
     return **maybeModelInSharedOut;
 }
 
-nn::GeneralResult<std::reference_wrapper<const nn::Request>> flushDataFromPointerToShared(
-        const nn::Request* request, std::optional<nn::Request>* maybeRequestInSharedOut) {
+template <>
+void InputRelocationTracker::flush() const {
+    // Copy from pointers to shared memory.
+    uint8_t* memoryPtr = static_cast<uint8_t*>(std::get<void*>(kMapping.pointer));
+    for (const auto& [data, length, offset] : kRelocationInfos) {
+        std::memcpy(memoryPtr + offset, data, length);
+    }
+}
+
+template <>
+void OutputRelocationTracker::flush() const {
+    // Copy from shared memory to pointers.
+    const uint8_t* memoryPtr = static_cast<const uint8_t*>(
+            std::visit([](auto ptr) { return static_cast<const void*>(ptr); }, kMapping.pointer));
+    for (const auto& [data, length, offset] : kRelocationInfos) {
+        std::memcpy(data, memoryPtr + offset, length);
+    }
+}
+
+nn::GeneralResult<std::reference_wrapper<const nn::Request>> convertRequestFromPointerToShared(
+        const nn::Request* request, std::optional<nn::Request>* maybeRequestInSharedOut,
+        RequestRelocation* relocationOut) {
     CHECK(request != nullptr);
     CHECK(maybeRequestInSharedOut != nullptr);
+    CHECK(relocationOut != nullptr);
 
     if (hasNoPointerData(*request)) {
         return *request;
@@ -213,8 +234,11 @@ nn::GeneralResult<std::reference_wrapper<const nn::Request>> flushDataFromPointe
     // to the caller through `maybeRequestInSharedOut` if the function succeeds.
     nn::Request requestInShared = *request;
 
+    RequestRelocation relocation;
+
     // Change input pointers to shared memory.
-    nn::ConstantMemoryBuilder inputBuilder(requestInShared.pools.size());
+    nn::MutableMemoryBuilder inputBuilder(requestInShared.pools.size());
+    std::vector<InputRelocationInfo> inputRelocationInfos;
     for (auto& input : requestInShared.inputs) {
         const auto& location = input.location;
         if (input.lifetime != nn::Request::Argument::LifeTime::POINTER) {
@@ -225,17 +249,21 @@ nn::GeneralResult<std::reference_wrapper<const nn::Request>> flushDataFromPointe
         const void* data = std::visit([](auto ptr) { return static_cast<const void*>(ptr); },
                                       location.pointer);
         CHECK(data != nullptr);
-        input.location = inputBuilder.append(data, location.length);
+        input.location = inputBuilder.append(location.length);
+        inputRelocationInfos.push_back({data, input.location.length, input.location.offset});
     }
 
     // Allocate input memory.
     if (!inputBuilder.empty()) {
         auto memory = NN_TRY(inputBuilder.finish());
-        requestInShared.pools.push_back(std::move(memory));
+        requestInShared.pools.push_back(memory);
+        relocation.input = NN_TRY(
+                InputRelocationTracker::create(std::move(inputRelocationInfos), std::move(memory)));
     }
 
     // Change output pointers to shared memory.
     nn::MutableMemoryBuilder outputBuilder(requestInShared.pools.size());
+    std::vector<OutputRelocationInfo> outputRelocationInfos;
     for (auto& output : requestInShared.outputs) {
         const auto& location = output.location;
         if (output.lifetime != nn::Request::Argument::LifeTime::POINTER) {
@@ -243,62 +271,25 @@ nn::GeneralResult<std::reference_wrapper<const nn::Request>> flushDataFromPointe
         }
 
         output.lifetime = nn::Request::Argument::LifeTime::POOL;
+        void* data = std::get<void*>(location.pointer);
+        CHECK(data != nullptr);
         output.location = outputBuilder.append(location.length);
+        outputRelocationInfos.push_back({data, output.location.length, output.location.offset});
     }
 
     // Allocate output memory.
     if (!outputBuilder.empty()) {
         auto memory = NN_TRY(outputBuilder.finish());
-        requestInShared.pools.push_back(std::move(memory));
+        requestInShared.pools.push_back(memory);
+        relocation.output = NN_TRY(OutputRelocationTracker::create(std::move(outputRelocationInfos),
+                                                                   std::move(memory)));
     }
 
     *maybeRequestInSharedOut = requestInShared;
+    *relocationOut = std::move(relocation);
     return **maybeRequestInSharedOut;
 }
 
-nn::GeneralResult<void> unflushDataFromSharedToPointer(
-        const nn::Request& request, const std::optional<nn::Request>& maybeRequestInShared) {
-    if (!maybeRequestInShared.has_value() || maybeRequestInShared->pools.empty() ||
-        !std::holds_alternative<nn::SharedMemory>(maybeRequestInShared->pools.back())) {
-        return {};
-    }
-    const auto& requestInShared = *maybeRequestInShared;
-
-    // Map the memory.
-    const auto& outputMemory = std::get<nn::SharedMemory>(requestInShared.pools.back());
-    const auto [pointer, size, context] = NN_TRY(map(outputMemory));
-    const uint8_t* constantPointer =
-            std::visit([](const auto& o) { return static_cast<const uint8_t*>(o); }, pointer);
-
-    // Flush each output pointer.
-    CHECK_EQ(request.outputs.size(), requestInShared.outputs.size());
-    for (size_t i = 0; i < request.outputs.size(); ++i) {
-        const auto& location = request.outputs[i].location;
-        const auto& locationInShared = requestInShared.outputs[i].location;
-        if (!std::holds_alternative<void*>(location.pointer)) {
-            continue;
-        }
-
-        // Get output pointer and size.
-        void* data = std::get<void*>(location.pointer);
-        CHECK(data != nullptr);
-        const size_t length = location.length;
-
-        // Get output pool location.
-        CHECK(requestInShared.outputs[i].lifetime == nn::Request::Argument::LifeTime::POOL);
-        const size_t index = locationInShared.poolIndex;
-        const size_t offset = locationInShared.offset;
-        const size_t outputPoolIndex = requestInShared.pools.size() - 1;
-        CHECK(locationInShared.length == length);
-        CHECK(index == outputPoolIndex);
-
-        // Flush memory.
-        std::memcpy(data, constantPointer + offset, length);
-    }
-
-    return {};
-}
-
 nn::GeneralResult<std::vector<uint32_t>> countNumberOfConsumers(
         size_t numberOfOperands, const std::vector<nn::Operation>& operations) {
     return makeGeneralFailure(nn::countNumberOfConsumers(numberOfOperands, operations));
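
For context, here is a minimal caller-side sketch of how the new conversion and relocation flow might be used. It is not part of this change: the function name executeWithPointerArguments, the driverExecute callable, the return type, and the exact header path are assumptions based only on the names visible in this diff. The idea it illustrates is the one the diff implements: convert POINTER arguments into shared pools once, flush the input tracker before handing the request to the driver, and flush the output tracker after results come back.

// Sketch only; header path and surrounding names are assumptions.
#include <nnapi/Types.h>
#include <nnapi/hal/CommonUtils.h>

namespace nn = ::android::nn;
namespace utils = ::android::hardware::neuralnetworks::utils;

// driverExecute stands in for whatever sends the request to the driver and
// waits for completion; it is passed in so this sketch stays self-contained.
template <typename ExecuteFn>
nn::GeneralResult<std::vector<nn::OutputShape>> executeWithPointerArguments(
        const nn::Request& request, ExecuteFn&& driverExecute) {
    // Convert POINTER arguments into shared input/output pools and record, for
    // each pointer, where its data must be copied to (inputs) or from (outputs).
    std::optional<nn::Request> maybeRequestInShared;
    utils::RequestRelocation relocation;
    const nn::Request& requestInShared = NN_TRY(utils::convertRequestFromPointerToShared(
            &request, &maybeRequestInShared, &relocation));

    // Copy input data from the caller's pointers into the shared input pool.
    if (relocation.input) {
        relocation.input->flush();
    }

    auto outputShapes = NN_TRY(driverExecute(requestInShared));

    // Copy results from the shared output pool back into the caller's pointers.
    if (relocation.output) {
        relocation.output->flush();
    }
    return outputShapes;
}

Compared with the old flushDataFromPointerToShared / unflushDataFromSharedToPointer pair, the copies are now decoupled from the conversion itself, so a caller that keeps the converted request around can re-flush the same trackers on every execution instead of redoing the conversion, which appears to be the motivation for introducing the relocation trackers.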