robocop

Checks KYC attributes against sanction lists
Log | Files | Refs | Submodules | README | LICENSE

main.rs (9476B)


      1 // This file is part of Robocop
      2 //
      3 // Robocop is free software: you can redistribute it and/or modify
      4 // it under the terms of the GNU General Public License as published by
      5 // the Free Software Foundation, either version 3 of the License, or
      6 // (at your option) any later version.
      7 //
      8 // Robocop is distributed in the hope that it will be useful,
      9 // but WITHOUT ANY WARRANTY; without even the implied warranty of
     10 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
     11 // GNU General Public License for more details.
     12 //
     13 // You should have received a copy of the GNU General Public License
     14 // along with this program. If not, see <https://www.gnu.org/licenses/>.
     15 //
     16 // Copyright (C) 2025 Taler Systems SA
     17 
     18 use serde_json::{Map, Value};
     19 use std::collections::HashMap;
     20 use std::env;
     21 use std::fs;
     22 use std::io::{self, BufRead, BufReader, Write};
     23 use std::process;
     24 
/// Version string printed by `print_version` (the `-v` / `--version` option).
const VERSION: &str = "1.0.0";
     26 
     27 fn print_version() {
     28     println!("robocop {}", VERSION);
     29 }
     30 
     31 fn print_help() {
     32     println!("Usage: robocop [OPTIONS] <sanctionlist>");
     33     println!();
     34     println!("Arguments:");
     35     println!("  <sanctionlist>    Sanction list in JSON to load");
     36     println!();
     37     println!("Options:");
     38     println!("  -h, --help     Show this help message and exit");
     39     println!("  -v, --version  Show version information and exit");
     40 }
     41 
     42 fn print_usage() {
     43     eprintln!("Usage: robocop [OPTIONS] <sanctionlist>");
     44     eprintln!("Try 'robocop --help' for more information.");
     45 }
     46 
     47 fn parse_args() -> Option<String> {
     48     let args: Vec<String> = env::args().collect();
     49 
     50     match args.len() {
     51         1 => {
     52             // No arguments provided
     53             print_usage();
     54             process::exit(1);
     55         }
     56         2 => {
     57             // Single argument
     58             match args[1].as_str() {
     59                 "-h" | "--help" => {
     60                     print_help();
     61                     process::exit(0);
     62                 }
     63                 "-v" | "--version" => {
     64                     print_version();
     65                     process::exit(0);
     66                 }
     67                 filename => {
     68                     // Return the filename to continue processing
     69                     Some(filename.to_string())
     70                 }
     71             }
     72         }
     73         _ => {
     74             // Too many arguments
     75             print_usage();
     76             process::exit(1);
     77         }
     78     }
     79 }
     80 
     81 #[derive(Debug, Clone)]
     82 struct Matching {
     83     // Final states with their associated values and costs
     84     final_states: HashMap<usize, (String, usize)>,
     85     max_state: usize,
     86 }
     87 
     88 impl Matching {
     89     fn new() -> Self {
     90         Self {
     91             final_states: HashMap::new(),
     92             max_state: 0,
     93         }
     94     }
     95 
     96     fn add_string(&mut self, s: &str) {
     97         let chars: Vec<char> = s.chars().collect();
     98         let mut state_ids = Vec::new();
     99 
    100         // Pre-allocate state IDs to avoid multiple mutable borrows
    101         for _ in 0..=chars.len() {
    102             let state_id = self.max_state;
    103             self.max_state += 1;
    104             state_ids.push(state_id);
    105         }
    106 
    107         // Add final states
    108         for (i, &state_id) in state_ids.iter().enumerate() {
    109             self.final_states
    110                 .insert(state_id, (s.to_string(), chars.len() - i));
    111         }
    112     }
    113 
    114     fn find_best_match(&self, input: &str) -> Option<(String, f64)> {
    115         let mut best_match = None;
    116         let mut best_score = 0.0;
    117 
    118         for (candidate, _) in self.final_states.values() {
    119             let distance = levenshtein_distance(input, candidate);
    120             let max_len = input.len().max(candidate.len());
    121             let score = if max_len == 0 {
    122                 1.0
    123             } else {
    124                 1.0 - (distance as f64 / max_len as f64)
    125             };
    126 
    127             if score > best_score {
    128                 best_score = score;
    129                 best_match = Some((candidate.clone(), score));
    130             }
    131         }
    132 
    133         best_match
    134     }
    135 }
    136 
    137 // Record structure for matching
    138 #[derive(Debug, Clone)]
    139 struct Record {
    140     ssid: String,
    141     fields: HashMap<String, Matching>,
    142 }
    143 
    144 impl Record {
    145     fn new(ssid: String) -> Self {
    146         Self {
    147             ssid,
    148             fields: HashMap::new(),
    149         }
    150     }
    151 
    152     fn add_field_values(&mut self, key: &str, values: &[String]) {
    153         let mut fsm = Matching::new();
    154         for value in values {
    155             fsm.add_string(value);
    156         }
    157         self.fields.insert(key.to_string(), fsm);
    158     }
    159 }
    160 
    161 // Matching engine
    162 struct MatchingEngine {
    163     records: Vec<Record>,
    164 }
    165 
    166 impl MatchingEngine {
    167     fn new() -> Self {
    168         Self {
    169             records: Vec::new(),
    170         }
    171     }
    172 
    173     fn load_from_json(&mut self, filename: &str) -> Result<(), Box<dyn std::error::Error>> {
    174         let content = fs::read_to_string(filename)?;
    175         let json_array: Vec<Value> = serde_json::from_str(&content)?;
    176 
    177         for (idx, item) in json_array.iter().enumerate() {
    178             if let Value::Object(obj) = item {
    179                 let ssid = obj
    180                     .get("ssid")
    181                     .and_then(|v| v.as_str())
    182                     .unwrap_or(&format!("record_{}", idx))
    183                     .to_string();
    184 
    185                 let mut record = Record::new(ssid);
    186 
    187                 for (key, value) in obj {
    188                     if key == "ssid" {
    189                         continue;
    190                     }
    191 
    192                     // Only process arrays
    193                     if let Value::Array(arr) = value {
    194                         let string_values: Vec<String> = arr
    195                             .iter()
    196                             .filter_map(|v| v.as_str().map(|s| s.to_string()))
    197                             .collect();
    198 
    199                         if !string_values.is_empty() {
    200                             record.add_field_values(key, &string_values);
    201                         }
    202                     }
    203                 }
    204 
    205                 self.records.push(record);
    206             }
    207         }
    208 
    209         Ok(())
    210     }
    211 
    212     fn find_best_match(&self, input: &Map<String, Value>) -> (f64, f64, String) {
    213         let mut best_overall_score = 0.0;
    214         let mut best_ssid = String::new();
    215         let mut best_avg_score = 0.0;
    216         let mut best_confidence = 0;
    217         let mut max_fields = 0;
    218 
    219         for record in &self.records {
    220             let mut total_score = 0.0;
    221             let mut matching_fields = 0;
    222             let total_fields = record.fields.len();
    223 
    224             for (key, input_value) in input {
    225                 if let Some(input_str) = input_value.as_str()
    226                     && let Some(fsm) = record.fields.get(key)
    227                     && let Some((_, score)) = fsm.find_best_match(input_str)
    228                 {
    229                     total_score += score;
    230                     matching_fields += 1;
    231                 }
    232             }
    233             max_fields = max_fields.max(total_fields);
    234             if total_fields > 0 && total_score > best_overall_score {
    235                 best_overall_score = total_score;
    236                 best_avg_score = total_score / matching_fields as f64;
    237                 best_confidence = matching_fields;
    238                 best_ssid = record.ssid.clone();
    239             }
    240         }
    241 
    242         (
    243             best_avg_score,
    244             best_confidence as f64 / max_fields as f64,
    245             best_ssid,
    246         )
    247     }
    248 }
    249 
    250 // Levenshtein distance implementation
    251 fn levenshtein_distance(s1: &str, s2: &str) -> usize {
    252     let chars1: Vec<char> = s1.chars().collect();
    253     let chars2: Vec<char> = s2.chars().collect();
    254     let len1 = chars1.len();
    255     let len2 = chars2.len();
    256 
    257     let mut matrix = vec![vec![0; len2 + 1]; len1 + 1];
    258 
    259     // Initialize first row and column
    260     for i in 0..=len1 {
    261         matrix[i][0] = i;
    262     }
    263     for j in 0..=len2 {
    264         matrix[0][j] = j;
    265     }
    266 
    267     // Fill the matrix
    268     for i in 1..=len1 {
    269         for j in 1..=len2 {
    270             let cost = if chars1[i - 1] == chars2[j - 1] { 0 } else { 1 };
    271             matrix[i][j] = std::cmp::min(
    272                 std::cmp::min(
    273                     matrix[i - 1][j] + 1, // deletion
    274                     matrix[i][j - 1] + 1, // insertion
    275                 ),
    276                 matrix[i - 1][j - 1] + cost, // substitution
    277             );
    278         }
    279     }
    280 
    281     matrix[len1][len2]
    282 }
    283 
    284 fn main() -> Result<(), Box<dyn std::error::Error>> {
    285     let filename = parse_args().unwrap();
    286 
    287     // Load and pre-process the JSON database
    288     let mut engine = MatchingEngine::new();
    289     engine.load_from_json(&filename)?;
    290 
    291     // Read JSON objects from stdin
    292     let stdin = io::stdin();
    293     let reader = BufReader::new(stdin);
    294 
    295     for line in reader.lines() {
    296         let line = line?;
    297         if line.trim().is_empty() {
    298             eprintln!("ERROR: empty input line");
    299             std::process::exit(1);
    300         }
    301 
    302         match serde_json::from_str::<Value>(&line) {
    303             Ok(Value::Object(obj)) => {
    304                 eprintln!("INFO: robocop received input: {}", line);
    305                 let (quality, confidence, ssid) = engine.find_best_match(&obj);
    306                 println!("{:.6} {:.6} {}", quality, confidence, ssid);
    307                 // Not 100% clear if flush is needed here, but safer.
    308                 io::stdout().flush().unwrap();
    309             }
    310             Ok(_) => {
    311                 eprintln!("ERROR: non-object JSON received: {}", line);
    312                 std::process::exit(1);
    313             }
    314             Err(e) => {
    315                 eprintln!("ERROR: Failed to parse JSON: {} - {}", line, e);
    316                 std::process::exit(1);
    317             }
    318         }
    319     }
    320 
    321     Ok(())
    322 }