summaryrefslogtreecommitdiff
path: root/neuralnetworks/1.0/vts/functional/Utils.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'neuralnetworks/1.0/vts/functional/Utils.cpp')
-rw-r--r--neuralnetworks/1.0/vts/functional/Utils.cpp43
1 files changed, 43 insertions, 0 deletions
diff --git a/neuralnetworks/1.0/vts/functional/Utils.cpp b/neuralnetworks/1.0/vts/functional/Utils.cpp
index 3613e69088..32850b060c 100644
--- a/neuralnetworks/1.0/vts/functional/Utils.cpp
+++ b/neuralnetworks/1.0/vts/functional/Utils.cpp
@@ -29,7 +29,11 @@
#include <gtest/gtest.h>
#include <algorithm>
+#include <cstring>
+#include <functional>
#include <iostream>
+#include <map>
+#include <numeric>
#include <vector>
namespace android::hardware::neuralnetworks {
@@ -172,6 +176,45 @@ std::vector<TestBuffer> ExecutionContext::getOutputBuffers(const Request& reques
return outputBuffers;
}
+uint32_t sizeOfData(V1_0::OperandType type) {
+ switch (type) {
+ case V1_0::OperandType::FLOAT32:
+ case V1_0::OperandType::INT32:
+ case V1_0::OperandType::UINT32:
+ case V1_0::OperandType::TENSOR_FLOAT32:
+ case V1_0::OperandType::TENSOR_INT32:
+ return 4;
+ case V1_0::OperandType::TENSOR_QUANT8_ASYMM:
+ return 1;
+ default:
+ CHECK(false) << "Invalid OperandType " << static_cast<uint32_t>(type);
+ return 0;
+ }
+}
+
+static bool isTensor(V1_0::OperandType type) {
+ switch (type) {
+ case V1_0::OperandType::FLOAT32:
+ case V1_0::OperandType::INT32:
+ case V1_0::OperandType::UINT32:
+ return false;
+ case V1_0::OperandType::TENSOR_FLOAT32:
+ case V1_0::OperandType::TENSOR_INT32:
+ case V1_0::OperandType::TENSOR_QUANT8_ASYMM:
+ return true;
+ default:
+ CHECK(false) << "Invalid OperandType " << static_cast<uint32_t>(type);
+ return false;
+ }
+}
+
+uint32_t sizeOfData(const V1_0::Operand& operand) {
+ const uint32_t dataSize = sizeOfData(operand.type);
+ if (isTensor(operand.type) && operand.dimensions.size() == 0) return 0;
+ return std::accumulate(operand.dimensions.begin(), operand.dimensions.end(), dataSize,
+ std::multiplies<>{});
+}
+
std::string gtestCompliantName(std::string name) {
// gtest test names must only contain alphanumeric characters
std::replace_if(