summaryrefslogtreecommitdiff
path: root/common/src/reconnect.rs
blob: f569fc95b2f0f9586500c7cdace1e64970f6ebd2 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
/*
  This file is part of TALER
  Copyright (C) 2022 Taler Systems SA

  TALER is free software; you can redistribute it and/or modify it under the
  terms of the GNU Affero General Public License as published by the Free Software
  Foundation; either version 3, or (at your option) any later version.

  TALER is distributed in the hope that it will be useful, but WITHOUT ANY
  WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
  A PARTICULAR PURPOSE.  See the GNU Affero General Public License for more details.

  You should have received a copy of the GNU Affero General Public License along with
  TALER; see the file COPYING.  If not, see <http://www.gnu.org/licenses/>
*/
use std::time::Duration;

use log::error;
use postgres::{Client, NoTls};

/// Pause between two failed connection attempts, also used as the timeout
/// of the database connection validity check
const RECONNECT_DELAY: Duration = Duration::from_secs(5);

/// Owner of a client of type `C` that is replaced by a freshly connected one
/// whenever `check` flags the current client
///
/// `S` is the configuration type handed to `connect` on every attempt.
pub struct AutoReconnect<S, C> {
    /// Configuration passed to `connect` on every connection attempt
    config: S,
    /// Client currently held
    client: C,
    /// Try to create a new client, `None` signals a failed attempt
    connect: fn(&S) -> Option<C>,
    /// Return `true` when the client must be replaced by a new one
    check: fn(&mut C) -> bool,
}

impl<S, C> AutoReconnect<S, C> {
    /// Build a wrapper around a first client, block until `connect` succeeds
    ///
    /// `connect` returns `None` on a failed attempt, `check` returns `true`
    /// when the client it is given must be replaced.
    pub fn new(config: S, connect: fn(&S) -> Option<C>, check: fn(&mut C) -> bool) -> Self {
        // The first client is created before the configuration is moved into the wrapper
        let client = Self::connect(&config, connect);
        Self {
            config,
            client,
            connect,
            check,
        }
    }

    /// Call `connect` until it yields a client, sleep between failed attempts
    fn connect(config: &S, connect: fn(&S) -> Option<C>) -> C {
        loop {
            if let Some(client) = connect(config) {
                break client;
            }
            std::thread::sleep(RECONNECT_DELAY);
        }
    }

    /// Borrow the client mutably, replace it first if `check` flags it
    ///
    /// Block until a replacement client can be created.
    pub fn client(&mut self) -> &mut C {
        let Self {
            config,
            client,
            connect,
            check,
        } = self;
        if (*check)(client) {
            // The new client is created before the flagged one is dropped
            *client = Self::connect(config, *connect);
        }
        client
    }
}

/// [`AutoReconnect`] holding a `postgres` [`Client`] created from a [`postgres::Config`]
pub type AutoReconnectDb = AutoReconnect<postgres::Config, Client>;

pub fn auto_reconnect_db(config: postgres::Config) -> AutoReconnectDb {
    AutoReconnect::new(
        config,
        |config| {
            config
                .connect(NoTls)
                .map_err(|err| error!("connect DB: {}", err))
                .ok()
        },
        |client| client.is_valid(RECONNECT_DELAY).is_err(),
    )
}