Kolmogorov-Smirnov metric-wise similarity

This commit is contained in:
eugen.betke 2020-09-03 12:20:21 +02:00
parent cd307c98da
commit 0545f89fc2
10 changed files with 52 additions and 28 deletions

Binary file not shown.

View File

@ -40,7 +40,11 @@ pub fn test<T: Ord + Clone>(xs: &[T], ys: &[T], confidence: f64) -> Result<TestR
assert!(0.0 < confidence && confidence < 1.0); assert!(0.0 < confidence && confidence < 1.0);
// Only supports samples of size > 7. // Only supports samples of size > 7.
assert!(xs.len() > 7 && ys.len() > 7); //assert!(xs.len() > 7 && ys.len() > 7);
// Precondition check, reported as a recoverable error instead of a panic.
// The error must fire when the precondition is NOT met, i.e. when either
// sample has 7 or fewer elements.
if !(xs.len() > 7 && ys.len() > 7) {
    return Err(String::from("Assertion violated: xs.len() > 7 && ys.len() > 7"));
}
let statistic = calculate_statistic(xs, ys); let statistic = calculate_statistic(xs, ys);
let critical_value = calculate_critical_value(xs.len(), ys.len(), confidence)?; let critical_value = calculate_critical_value(xs.len(), ys.len(), confidence)?;
@ -198,7 +202,11 @@ fn calculate_statistic<T: Ord + Clone>(xs: &[T], ys: &[T]) -> f64 {
/// evidence exceeds the confidence level required. /// evidence exceeds the confidence level required.
fn calculate_reject_probability(statistic: f64, n1: usize, n2: usize) -> Result<f64, String> { fn calculate_reject_probability(statistic: f64, n1: usize, n2: usize) -> Result<f64, String> {
// Only supports samples of size > 7. // Only supports samples of size > 7.
assert!(n1 > 7 && n2 > 7); // assert!(n1 > 7 && n2 > 7);
// Precondition check, reported as a recoverable error instead of a panic.
// The error must fire when the precondition is NOT met, i.e. when either
// sample size is 7 or smaller.
if !(n1 > 7 && n2 > 7) {
    return Err(String::from("Assertion violated: n1 > 7 && n2 > 7"));
}
let n1 = n1 as f64; let n1 = n1 as f64;
let n2 = n2 as f64; let n2 = n2 as f64;
@ -234,7 +242,10 @@ pub fn calculate_critical_value(n1: usize, n2: usize, confidence: f64) -> Result
assert!(0.0 < confidence && confidence < 1.0); assert!(0.0 < confidence && confidence < 1.0);
// Only supports samples of size > 7. // Only supports samples of size > 7.
assert!(n1 > 7 && n2 > 7); //assert!(n1 > 7 && n2 > 7);
// Precondition check, reported as a recoverable error instead of a panic.
// The error must fire when the precondition is NOT met, i.e. when either
// sample size is 7 or smaller.
if !(n1 > 7 && n2 > 7) {
    return Err(String::from("Assertion violated: n1 > 7 && n2 > 7"));
}
// The test statistic is between zero and one so can binary search quickly // The test statistic is between zero and one so can binary search quickly
// for the critical value. // for the critical value.

View File

@ -40,17 +40,30 @@ pub struct SimilarityRow {
pub jobid: u32, pub jobid: u32,
pub alg_id: u32, pub alg_id: u32,
pub alg_name: String, pub alg_name: String,
pub similarity: f32 pub similarity: f32,
pub status: String,
} }
//#[derive(Debug, Serialize)]
//pub struct ProgressRow {
// jobid: u32,
// alg_id: u32,
// alg_name: String,
// delta: i64,
//}
#[derive(Debug, Serialize)] #[derive(Debug, Serialize)]
pub struct ProgressRow { pub struct ProgressRow {
jobid: u32, iteration: u32,
alg_id: u32, alg_id: u32,
alg_name: String, alg_name: String,
delta: i64, jobs_done: usize,
jobs_total: usize,
elapsed: f64,
delta: f64,
} }
pub fn convert_to_coding(coding: String) -> Vec<Score> { pub fn convert_to_coding(coding: String) -> Vec<Score> {
let split = coding.split(":"); let split = coding.split(":");
let vec: Vec<Score> = split let vec: Vec<Score> = split
@ -112,31 +125,35 @@ fn run(dataset_fn: String, jobid: Jobid, similarities_fn: String, progress_fn: S
let mut avail_codings: Vec<(u32, &JobCoding)>; let mut avail_codings: Vec<(u32, &JobCoding)>;
avail_codings = q_codings.iter().map(|(k, v)| (*k, v)).collect(); avail_codings = q_codings.iter().map(|(k, v)| (*k, v)).collect();
let mut similarities: Vec<(Jobid, Similarity)> = Vec::new(); let mut similarities: Vec<(Jobid, Similarity, bool)> = Vec::new();
let log_file = File::create(&log_fn).expect("Unable to open"); let log_file = File::create(&log_fn).expect("Unable to open");
let mut log_file = LineWriter::new(log_file); let mut log_file = LineWriter::new(log_file);
let probe = q_codings[&jobid].clone(); let probe = q_codings[&jobid].clone();
let mut start = chrono::Utc::now(); let mut start_chunk = chrono::Utc::now();
let start = start_chunk;
while let Some((jobid, q_coding)) = avail_codings.pop() { while let Some((jobid, q_coding)) = avail_codings.pop() {
if (counter % 10_000) == 0 { if (counter % 10_000) == 0 {
let stop = chrono::Utc::now(); let stop_chunk = chrono::Utc::now();
let progress_row = ProgressRow { let progress_row = ProgressRow {
jobid: jobid, iteration: 0,
alg_id: alg_id, alg_id: alg_id,
alg_name: String::from(alg_name), alg_name: String::from(alg_name),
delta: ((stop - start).num_nanoseconds().unwrap()) jobs_done: counter,
jobs_total: q_codings.len(),
elapsed: (((stop_chunk - start).num_milliseconds() as f64) / 1000.0),
delta: (((stop_chunk - start_chunk).num_milliseconds() as f64) / 1000.0),
}; };
wtr_progress.serialize(progress_row).unwrap(); wtr_progress.serialize(progress_row).unwrap();
start = stop; start_chunk = stop_chunk;
} }
//println!("Processing {:?}", jobid); //println!("Processing {:?}", jobid);
//let similarity = ks_similarity(q_coding, &probe); //let similarity = ks_similarity(q_coding, &probe);
let mut metric_similarities = vec![]; let mut metric_similarities = vec![];
let mut err = false;
let confidence = 0.95; let confidence = 0.95;
for metric_codings in q_coding.iter().zip(&probe) { for metric_codings in q_coding.iter().zip(&probe) {
@ -145,36 +162,32 @@ fn run(dataset_fn: String, jobid: Jobid, similarities_fn: String, progress_fn: S
(1.0 - sim.reject_probability) as Similarity (1.0 - sim.reject_probability) as Similarity
} }
Err(e) => { Err(e) => {
let message = format!("jobid failed {:?}, because \" {:?}\"\n", jobid, e); err = true;
let message = format!("jobid failed {:?}, because {:?}\n", jobid, e);
log_file.write_all(message.as_bytes()).unwrap(); log_file.write_all(message.as_bytes()).unwrap();
1.0 0.0
} }
}; };
metric_similarities.push(metric_similarity); metric_similarities.push(metric_similarity);
} }
let similarity = metric_similarities.iter().sum::<f32>() / (metric_similarities.len() as f32); let similarity = metric_similarities.iter().sum::<f32>() / (metric_similarities.len() as f32);
//let similarity = match ks::test(q_coding, &probe, confidence) { similarities.push((jobid, similarity, err));
// Ok(sim) => {
// (1.0 - sim.reject_probability) as Similarity,
// }
// Err(e) => {
// let message = format!("jobid failed {:?}, because \" {:?}\"\n", jobid, e);
// log_file.write_all(message.as_bytes()).unwrap();
// 1.0
// }
//};
similarities.push((jobid, similarity));
counter += 1; counter += 1;
} }
for (jobid, similarity) in similarities.iter() {
let mut status_map: HashMap<bool, String> = HashMap::new();
status_map.insert(false, String::from("ok"));
status_map.insert(true, String::from("failed"));
for (jobid, similarity, err) in similarities.iter() {
let similarity_row = SimilarityRow { let similarity_row = SimilarityRow {
jobid: *jobid, jobid: *jobid,
alg_id: alg_id, alg_id: alg_id,
alg_name: String::from(alg_name), alg_name: String::from(alg_name),
similarity: *similarity, similarity: *similarity,
status: status_map[err].clone(),
}; };
wtr_similarities.serialize(similarity_row).unwrap(); wtr_similarities.serialize(similarity_row).unwrap();
} }