From a0629343692a7f008e01f9e056330b2202783d90 Mon Sep 17 00:00:00 2001 From: Lyle Mantooth Date: Sat, 20 May 2023 22:42:03 -0400 Subject: [PATCH] Analytics! --- Cargo.toml | 3 +- flake.nix | 2 +- src/lib.rs | 116 +++++++++++++++++++++++++++++++++++++++++++++++++---- 3 files changed, 112 insertions(+), 9 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 854b611..36707d5 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "locat" -version = "0.3.1" +version = "0.4.0" edition = "2021" publish = ["menteeth"] @@ -8,4 +8,5 @@ publish = ["menteeth"] [dependencies] maxminddb = "0.23" +rusqlite = "0.28" thiserror = "1" diff --git a/flake.nix b/flake.nix index 2f907a6..8e90e68 100644 --- a/flake.nix +++ b/flake.nix @@ -14,7 +14,7 @@ { defaultPackage = naersk-lib.buildPackage ./.; devShell = with pkgs; mkShell { - buildInputs = [ cargo rustc rustfmt pre-commit rustPackages.clippy ]; + buildInputs = [ cargo rustc rustfmt pre-commit rustPackages.clippy sqlite]; RUST_SRC_PATH = rustPlatform.rustLibSrc; }; }); diff --git a/src/lib.rs b/src/lib.rs index 7bb21c1..42a10c7 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -2,35 +2,137 @@ use std::net::IpAddr; /// Allows geo-locating IPs and keeps analytics. pub struct Locat { - geoip: maxminddb::Reader>, + reader: maxminddb::Reader>, + analytics: Db } #[derive(Debug, thiserror::Error)] pub enum Error { #[error("maxminddb error: {0}")] MaxMindDb(#[from] maxminddb::MaxMindDBError), + + #[error("rusqlite error: {0}")] + Rusqlite(#[from] rusqlite::Error), } impl Locat { - pub fn new(geoip_country_db_path: &str, _analytics_db_path: &str) -> Result { + pub fn new(geoip_country_db_path: &str, analytics_db_path: &str) -> Result { // Todo: create analytics db. Ok(Self { - geoip: maxminddb::Reader::open_readfile(geoip_country_db_path)?, + reader: maxminddb::Reader::open_readfile(geoip_country_db_path)?, + analytics: Db { + path: analytics_db_path.to_string(), + }, }) } /// Converts an address to an ISO 3166-1 alpha-2 country code. pub async fn ip_to_iso_code(&self, addr: IpAddr) -> Option<&str> { - self.geoip + let iso_code = self + .reader .lookup::(addr) .ok()? .country? - .iso_code + .iso_code?; + + if let Err(e) = self.analytics.increment(iso_code) { + eprintln!("Could not increment analytic: {e}"); + } + + Some(iso_code) } /// Returns a map of country codes to number of requests. - pub async fn get_analytics(&self) -> Vec<(String, u64)> { - Default::default() + pub async fn get_analytics(&self) -> Result, Error> { + Ok(self.analytics.list()?) + } +} + +struct Db { + path: String, +} + +impl Db { + fn list(&self) -> Result, rusqlite::Error> { + let conn = self.get_conn()?; + let mut stmt = conn.prepare("SELECT iso_code, count FROM analytics")?; + let mut rows = stmt.query([])?; + let mut analytics = Vec::new(); + while let Some(row) = rows.next()? { + let iso_code: String = row.get(0)?; + let count: u64 = row.get(1)?; + analytics.push((iso_code, count)); + } + Ok(analytics) + } + + fn increment(&self, iso_code: &str) -> Result<(), rusqlite::Error> { + let conn = self.get_conn().unwrap(); + let mut stmt = conn + .prepare("INSERT INTO analytics (iso_code, count) VALUES (?, 1) ON CONFLICT (iso_code) DO UPDATE SET count = count + 1")?; + stmt.execute([iso_code])?; + Ok(()) + } + + fn get_conn(&self) -> Result { + let conn = rusqlite::Connection::open(&self.path).unwrap(); + self.migrate(&conn)?; + Ok(conn) + } + + fn migrate(&self, conn: &rusqlite::Connection) -> Result<(), rusqlite::Error> { + // create analytics table. + conn.execute( + "CREATE TABLE IF NOT EXISTS analytics ( + iso_code TEXT PRIMARY KEY, + count INTEGER NOT NULL + )", + [], + )?; + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use crate::Db; + + struct RemoveOnDrop { + path: String, + } + + impl Drop for RemoveOnDrop { + fn drop(&mut self) { + _ = std::fs::remove_file(&self.path); + } + } + + #[test] + fn test_db() { + let db = Db { + path: "/tmp/locat-test.db".to_string(), + }; + let _remove_on_drop = RemoveOnDrop { + path: db.path.clone(), + }; + + let analytics = db.list().unwrap(); + assert_eq!(analytics.len(), 0); + + db.increment("US").unwrap(); + let analytics = db.list().unwrap(); + assert_eq!(analytics.len(), 1); + + db.increment("US").unwrap(); + db.increment("FR").unwrap(); + let analytics = db.list().unwrap(); + assert_eq!(analytics.len(), 2); + // contains US at count 2 + assert!(analytics.contains(&("US".to_string(), 2))); + // contains FR at count 1 + assert!(analytics.contains(&("FR".to_string(), 1))); + // does not contain DE + assert!(!analytics.contains(&("DE".to_string(), 0))); } }