summaryrefslogtreecommitdiff
path: root/system/gd/rust/gddi/src/lib.rs
blob: 1e005431f079b61a127c13153fc39e1eadd3cd17 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
//! Core dependency injection objects

use std::any::{Any, TypeId};
use std::collections::HashMap;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use tokio::sync::Mutex;

pub use gddi_macros::{module, part_out, provides, Stoppable};

/// Type-erased storage for a cached instance held by the registry.
type InstanceBox = Box<dyn Any + Send + Sync>;
/// A box around a future for a provider that is safe to send between threads
///
/// The future resolves to the constructed instance boxed as `dyn Any`; `Registry::get`
/// downcasts it back to the concrete type that was requested.
pub type ProviderFutureBox = Box<dyn Future<Output = Box<dyn Any>> + Send + Sync>;
/// A provider function: given the registry, returns a pinned future that builds an instance.
type ProviderFnBox = Box<dyn Fn(Arc<Registry>) -> Pin<ProviderFutureBox> + Send + Sync>;

/// Shutdown hook for objects managed by the registry.
///
/// `Registry::stop_all` calls [`Stoppable::stop`] on every instance the registry constructed.
/// The default implementation does nothing, so a type with no resources to release can use an
/// empty `impl`.
pub trait Stoppable {
    /// Stop and close all resources held by this object. A no-op unless overridden.
    fn stop(&self) {
        // Intentionally empty: doing nothing is the correct default for stateless objects.
    }
}

/// Builder for Registry
///
/// Collects provider functions keyed by the `TypeId` of the type each one produces; call
/// [`RegistryBuilder::build`] to turn it into a [`Registry`].
pub struct RegistryBuilder {
    /// Registered providers, keyed by the `TypeId` of the type each one constructs.
    providers: HashMap<TypeId, Provider>,
}

/// Keeps track of central injection state
///
/// Each piece of state lives behind its own async (`tokio::sync`) mutex inside an `Arc`.
pub struct Registry {
    /// Provider functions, keyed by the `TypeId` of the type they construct.
    providers: Arc<Mutex<HashMap<TypeId, Provider>>>,
    /// Cached instances (built by a provider or added via `inject`), keyed by `TypeId`.
    instances: Arc<Mutex<HashMap<TypeId, InstanceBox>>>,
    /// Instances built by `get`, in creation order; `stop_all` pops from the end so they are
    /// stopped in reverse order of start.
    start_order: Arc<Mutex<Vec<Box<dyn Stoppable + Send + Sync>>>>,
}

/// A type-erased provider function. The function is shared through an `Arc`, so cloning a
/// `Provider` out of the providers map only bumps a reference count.
#[derive(Clone)]
struct Provider {
    /// Builds an instance (boxed as `dyn Any`) given the registry.
    f: Arc<ProviderFnBox>,
}

impl Default for RegistryBuilder {
    fn default() -> Self {
        Self::new()
    }
}

impl RegistryBuilder {
    /// Returns a builder with no providers registered yet.
    pub fn new() -> Self {
        let providers = HashMap::new();
        Self { providers }
    }

    /// Hands this builder to a module's registration function and returns what it gives back.
    ///
    /// This lets modules be chained fluently:
    /// `RegistryBuilder::new().register_module(a).register_module(b)`.
    pub fn register_module<F: Fn(Self) -> Self>(self, init: F) -> Self {
        init(self)
    }

    /// Records `f` as the provider for `T`, replacing any provider previously registered for `T`.
    pub fn register_provider<T: 'static>(mut self, f: ProviderFnBox) -> Self {
        let key = TypeId::of::<T>();
        let provider = Provider { f: Arc::new(f) };
        self.providers.insert(key, provider);
        self
    }

    /// Consumes the builder, producing a [`Registry`] with no instances cached or started.
    pub fn build(self) -> Registry {
        let RegistryBuilder { providers } = self;
        let instances = HashMap::new();
        let start_order = Vec::new();
        Registry {
            providers: Arc::new(Mutex::new(providers)),
            instances: Arc::new(Mutex::new(instances)),
            start_order: Arc::new(Mutex::new(start_order)),
        }
    }
}

impl Registry {
    /// Gets an instance of a type, implicitly starting any dependencies if necessary
    ///
    /// If an instance of `T` is cached (constructed earlier or added with
    /// [`Registry::inject`]), a clone of it is returned. Otherwise the provider registered for
    /// `T` runs with a clone of this `Arc<Registry>`, which lets it `get` its own dependencies.
    /// The result is cached and recorded so that [`Registry::stop_all`] will stop it.
    ///
    /// # Panics
    ///
    /// Panics if `T` is not cached and has no registered provider, or if a cached value or
    /// provider result is not actually a `T`.
    pub async fn get<T: 'static + Clone + Send + Sync + Stoppable>(self: &Arc<Self>) -> T {
        let typeid = TypeId::of::<T>();
        {
            let instances = self.instances.lock().await;
            if let Some(value) = instances.get(&typeid) {
                return value.downcast_ref::<T>().expect("was not correct type").clone();
            }
        }

        // No lock is held while the provider runs: it may call `get` recursively for its own
        // dependencies, and the tokio mutexes are not reentrant.
        let casted = {
            let provider = {
                let providers = self.providers.lock().await;
                providers.get(&typeid).cloned().unwrap_or_else(|| {
                    panic!("no provider registered for {}", std::any::type_name::<T>())
                })
            };
            let result = (provider.f)(self.clone()).await;
            (*result.downcast::<T>().expect("was not correct type")).clone()
        };

        // Acquire `start_order` before `instances`, the same order `stop_all` uses. With the
        // opposite order, a concurrent `get` and `stop_all` could each hold one lock while
        // waiting forever on the other.
        let mut start_order = self.start_order.lock().await;
        let mut instances = self.instances.lock().await;
        // NOTE(review): two concurrent `get::<T>()` calls that both miss the cache will each run
        // the provider; the later insert wins the cache, but both instances are recorded here.
        instances.insert(typeid, Box::new(casted.clone()));
        start_order.push(Box::new(casted.clone()));

        casted
    }

    /// Inject an already created instance of T. Useful for config.
    ///
    /// The instance is stored in the cache, replacing any existing `T`. Later `get::<T>()` calls
    /// then return clones of it without running a provider. Injected instances are not recorded
    /// for stopping.
    pub async fn inject<T: 'static + Clone + Send + Sync>(self: &Arc<Self>, obj: T) {
        let mut instances = self.instances.lock().await;
        instances.insert(TypeId::of::<T>(), Box::new(obj));
    }

    /// Stop all instances, in reverse order of start.
    ///
    /// Calls `stop` on every instance constructed by `get`, most recent first. It then clears
    /// the instance cache, including injected instances.
    pub async fn stop_all(self: &Arc<Self>) {
        let mut start_order = self.start_order.lock().await;
        while let Some(obj) = start_order.pop() {
            obj.stop();
        }
        // `start_order` is still held while `instances` is locked; `get` takes the two locks in
        // this same order, which keeps them deadlock-free with respect to each other.
        self.instances.lock().await.clear();
    }
}

impl<T> Stoppable for std::sync::Arc<T> {}