diff options
Diffstat (limited to 'neuralnetworks/1.0/vts/functional/Utils.cpp')
-rw-r--r-- | neuralnetworks/1.0/vts/functional/Utils.cpp | 43 |
1 files changed, 43 insertions, 0 deletions
diff --git a/neuralnetworks/1.0/vts/functional/Utils.cpp b/neuralnetworks/1.0/vts/functional/Utils.cpp index 3613e69088..32850b060c 100644 --- a/neuralnetworks/1.0/vts/functional/Utils.cpp +++ b/neuralnetworks/1.0/vts/functional/Utils.cpp @@ -29,7 +29,11 @@ #include <gtest/gtest.h> #include <algorithm> +#include <cstring> +#include <functional> #include <iostream> +#include <map> +#include <numeric> #include <vector> namespace android::hardware::neuralnetworks { @@ -172,6 +176,45 @@ std::vector<TestBuffer> ExecutionContext::getOutputBuffers(const Request& reques return outputBuffers; } +uint32_t sizeOfData(V1_0::OperandType type) { + switch (type) { + case V1_0::OperandType::FLOAT32: + case V1_0::OperandType::INT32: + case V1_0::OperandType::UINT32: + case V1_0::OperandType::TENSOR_FLOAT32: + case V1_0::OperandType::TENSOR_INT32: + return 4; + case V1_0::OperandType::TENSOR_QUANT8_ASYMM: + return 1; + default: + CHECK(false) << "Invalid OperandType " << static_cast<uint32_t>(type); + return 0; + } +} + +static bool isTensor(V1_0::OperandType type) { + switch (type) { + case V1_0::OperandType::FLOAT32: + case V1_0::OperandType::INT32: + case V1_0::OperandType::UINT32: + return false; + case V1_0::OperandType::TENSOR_FLOAT32: + case V1_0::OperandType::TENSOR_INT32: + case V1_0::OperandType::TENSOR_QUANT8_ASYMM: + return true; + default: + CHECK(false) << "Invalid OperandType " << static_cast<uint32_t>(type); + return false; + } +} + +uint32_t sizeOfData(const V1_0::Operand& operand) { + const uint32_t dataSize = sizeOfData(operand.type); + if (isTensor(operand.type) && operand.dimensions.size() == 0) return 0; + return std::accumulate(operand.dimensions.begin(), operand.dimensions.end(), dataSize, + std::multiplies<>{}); +} + std::string gtestCompliantName(std::string name) { // gtest test names must only contain alphanumeric characters std::replace_if( |