From 3f348fa7c64d0f5c2d3c6886867dda617ef7956d Mon Sep 17 00:00:00 2001 From: Astro Date: Fri, 27 Aug 2021 21:45:53 +0200 Subject: [PATCH] import_osm: refactor --- import_osm/Cargo.lock | 2 - import_osm/Cargo.toml | 3 +- import_osm/src/main.rs | 280 ++++++++++++++++++++--------------------- 3 files changed, 142 insertions(+), 143 deletions(-) diff --git a/import_osm/Cargo.lock b/import_osm/Cargo.lock index b0877a8..d8f913f 100644 --- a/import_osm/Cargo.lock +++ b/import_osm/Cargo.lock @@ -488,8 +488,6 @@ checksum = "624a8340c38c1b80fd549087862da4ba43e08858af025b236e509b6649fc13d5" [[package]] name = "osm_pbf_iter" version = "0.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b1e22dc554505976589d669702894e4ba5fc320ef232808f3f143adfc48e2f0a" dependencies = [ "byteorder", "flate2", diff --git a/import_osm/Cargo.toml b/import_osm/Cargo.toml index 7070589..eec3e2e 100644 --- a/import_osm/Cargo.toml +++ b/import_osm/Cargo.toml @@ -11,7 +11,8 @@ lto = true opt-level = 3 [dependencies] -osm_pbf_iter = "0.2" +#osm_pbf_iter = "0.2" +osm_pbf_iter = { path = "../../../programming/rust-osm-pbf-iter" } num_cpus = "1" postgres = { version = "0.19", features = ["with-geo-types-0_7", "with-serde_json-1"] } serde = { version = "1.0", features = ["derive"] } diff --git a/import_osm/src/main.rs b/import_osm/src/main.rs index c11546e..45b90fb 100644 --- a/import_osm/src/main.rs +++ b/import_osm/src/main.rs @@ -3,177 +3,177 @@ use std::env::args; use std::fs::File; use std::io::{Seek, SeekFrom, BufReader}; use std::time::Instant; -use std::sync::mpsc::{sync_channel, SyncSender, Receiver}; +use std::sync::mpsc::{sync_channel, Receiver}; use std::sync::Arc; use std::thread; use osm_pbf_iter::*; -fn phase1_worker(req_rx: Receiver, res_tx: SyncSender>) { - let mut res = HashMap::new(); - - loop { - let blob = match req_rx.recv() { - Ok(blob) => blob, - Err(_) => break, - }; - - let data = blob.into_data(); - let primitive_block = PrimitiveBlock::parse(&data); - for primitive in primitive_block.primitives() { - match primitive { - Primitive::Node(node) => { - res.insert(node.id as i64, (node.lon, node.lat)); - } - Primitive::Way(_) => {} - Primitive::Relation(_) => {} - } - } - } - - res_tx.send(res).unwrap(); +pub struct PrimSource { + req_rx: Receiver } -fn phase2_worker(req_rx: Receiver, res_tx: SyncSender<()>, node_coords: Arc>) { - const DB_URL: &str = "host=10.233.1.2 dbname=treeadvisor user=treeadvisor password=123"; +impl PrimSource { + pub fn recv_primitives R, R>(&self, f: F) -> Option { + self.req_rx.recv() + .ok() + .map(|blob| { + let data = blob.into_data(); + let primitive_block = PrimitiveBlock::parse(&data); + f(primitive_block.primitives()) + }) + } +} - let mut db = postgres::Client::connect(DB_URL, postgres::NoTls) - .expect("DB"); +fn process_osm R + 'static + Send + Clone, R: Send + 'static>(filename: &str, f: F) -> Vec { + let cpus = num_cpus::get(); + let mut worker_results = Vec::with_capacity(cpus); - loop { - let blob = match req_rx.recv() { - Ok(blob) => blob, - Err(_) => break, - }; + // start workers + let mut workers = Vec::with_capacity(cpus); + for _ in 0..cpus { + let (req_tx, req_rx) = sync_channel::(2); + let (res_tx, res_rx) = sync_channel::(1); + workers.push((req_tx, res_rx)); - let data = blob.into_data(); - let primitive_block = PrimitiveBlock::parse(&data); - let mut tx = db.transaction().unwrap(); - for primitive in primitive_block.primitives() { - match primitive { - Primitive::Node(_) => {} - Primitive::Way(way) => { - let tags: serde_json::Map = way.tags() - .map(|(k, v)| (k.to_string(), serde_json::Value::String(v.to_string()))) - .collect(); - let points = way.refs() - .filter_map(|id| node_coords.get(&id)) - .cloned() - .collect::>(); - tx.execute( - "INSERT INTO osm_ways (geo, id, attrs) VALUES ($1, $2, $3)", - &[&geo::LineString::from(points), &(way.id as i64), &serde_json::Value::Object(tags)] - ).unwrap(); - } - Primitive::Relation(_) => {} - } - } - tx.commit().unwrap(); + let f = f.clone(); + thread::spawn(move || { + let prim_src = PrimSource { req_rx }; + let result = f(prim_src); + res_tx.send(result).unwrap(); + }); } - res_tx.send(()).unwrap(); + // open file + println!("Open {}", filename); + let f = File::open(filename).unwrap(); + let mut reader = BlobReader::new(BufReader::new(f)); + let start = Instant::now(); + + // feed + let mut w = 0; + for blob in &mut reader { + let req_tx = &workers[w].0; + w = (w + 1) % cpus; + + req_tx.send(blob).unwrap(); + } + + // receive results + for (req_tx, res_rx) in workers.into_iter() { + drop(req_tx); + let worker_res = res_rx.recv().unwrap(); + worker_results.push(worker_res); + } + + // stats + let stop = Instant::now(); + let duration = stop.duration_since(start); + let duration = duration.as_secs() as f64 + (duration.subsec_nanos() as f64 / 1e9); + let mut f = reader.into_inner(); + match f.seek(SeekFrom::Current(0)) { + Ok(pos) => { + let rate = pos as f64 / 1024f64 / 1024f64 / duration; + println!("Processed {} MB in {:.2} seconds ({:.2} MB/s)", + pos / 1024 / 1024, duration, rate); + }, + Err(_) => (), + } + + worker_results } fn main() { - let cpus = num_cpus::get(); - let mut node_coords: HashMap = HashMap::new(); + // phase 1: nodes for arg in args().skip(1) { - let mut phase1_workers = Vec::with_capacity(cpus); - for _ in 0..cpus { - let (req_tx, req_rx) = sync_channel(2); - let (res_tx, res_rx) = sync_channel(0); - phase1_workers.push((req_tx, res_rx)); - thread::spawn(move || { - phase1_worker(req_rx, res_tx); - }); - } + let worker_res = process_osm(&arg, move |prim_src| { + let mut res = HashMap::new(); - println!("Phase1: open {}", arg); - let f = File::open(&arg).unwrap(); - let mut reader = BlobReader::new(BufReader::new(f)); - let start = Instant::now(); + while prim_src.recv_primitives(|iter| { + for primitive in iter { + match primitive { + Primitive::Node(node) => { + res.insert(node.id as i64, (node.lon, node.lat)); + } + Primitive::Way(_) => {} + Primitive::Relation(_) => {} + } + } - let mut w = 0; - for blob in &mut reader { - let req_tx = &phase1_workers[w].0; - w = (w + 1) % cpus; - - req_tx.send(blob).unwrap(); - } - - for (req_tx, res_rx) in phase1_workers.into_iter() { - drop(req_tx); - let mut worker_res = res_rx.recv().unwrap(); + true + }).unwrap_or(false) {} + res + }); + for mut res in worker_res { if node_coords.is_empty() { - node_coords = worker_res; + node_coords = res; } else { // merge - for (id, coords) in worker_res.drain() { + for (id, coords) in res.drain() { node_coords.insert(id, coords); } } } - - let stop = Instant::now(); - let duration = stop.duration_since(start); - let duration = duration.as_secs() as f64 + (duration.subsec_nanos() as f64 / 1e9); - let mut f = reader.into_inner(); - match f.seek(SeekFrom::Current(0)) { - Ok(pos) => { - let rate = pos as f64 / 1024f64 / 1024f64 / duration; - println!("Phase1: Processed {} MB in {:.2} seconds ({:.2} MB/s)", - pos / 1024 / 1024, duration, rate); - }, - Err(_) => (), - } - } println!("{} nodes", node_coords.len()); let node_coords = Arc::new(node_coords); + let mut way_coords: HashMap> = HashMap::new(); + // phase 2: ways for arg in args().skip(1) { - let mut phase2_workers = Vec::with_capacity(cpus); - for _ in 0..cpus { - let (req_tx, req_rx) = sync_channel(2); - let (res_tx, res_rx) = sync_channel(0); - phase2_workers.push((req_tx, res_rx)); - let node_coords = node_coords.clone(); - thread::spawn(move || { - phase2_worker(req_rx, res_tx, node_coords); - }); + let node_coords = node_coords.clone(); + let worker_res = process_osm(&arg, move |prim_src| { + const DB_URL: &str = "host=10.233.1.2 dbname=treeadvisor user=treeadvisor password=123"; + + let mut db = postgres::Client::connect(DB_URL, postgres::NoTls) + .expect("DB"); + + let mut res = HashMap::new(); + + let mut running = true; + while running { + running = prim_src.recv_primitives(|iter| { + let mut tx = db.transaction().unwrap(); + for primitive in iter { + match primitive { + Primitive::Node(_) => {} + Primitive::Way(way) => { + let tags: serde_json::Map = way.tags() + .map(|(k, v)| (k.to_string(), serde_json::Value::String(v.to_string()))) + .collect(); + let points = way.refs() + .filter_map(|id| node_coords.get(&id)) + .cloned() + .collect::>(); + tx.execute( + "INSERT INTO osm_ways (geo, id, attrs) VALUES ($1, $2, $3)", + &[&geo::LineString::from(points.clone()), &(way.id as i64), &serde_json::Value::Object(tags)] + ).unwrap(); + res.insert(way.id as i64, points); + } + Primitive::Relation(_) => {} + } + } + tx.commit().unwrap(); + + true + }).unwrap_or(false); + } + res + }); + for mut res in worker_res { + if way_coords.is_empty() { + way_coords = res; + } else { + // merge + for (id, coords) in res.drain() { + way_coords.insert(id, coords); + } + } } - - println!("Phase2: open {}", arg); - let f = File::open(&arg).unwrap(); - let mut reader = BlobReader::new(BufReader::new(f)); - let start = Instant::now(); - - let mut w = 0; - for blob in &mut reader { - let req_tx = &phase2_workers[w].0; - w = (w + 1) % cpus; - - req_tx.send(blob).unwrap(); - } - - for (req_tx, res_rx) in phase2_workers.into_iter() { - drop(req_tx); - let _worker_res = res_rx.recv().unwrap(); - } - - let stop = Instant::now(); - let duration = stop.duration_since(start); - let duration = duration.as_secs() as f64 + (duration.subsec_nanos() as f64 / 1e9); - let mut f = reader.into_inner(); - match f.seek(SeekFrom::Current(0)) { - Ok(pos) => { - let rate = pos as f64 / 1024f64 / 1024f64 / duration; - println!("Phase2: Processed {} MB in {:.2} seconds ({:.2} MB/s)", - pos / 1024 / 1024, duration, rate); - }, - Err(_) => (), - } - } + let way_coords = Arc::new(way_coords); + + // phase 3: rels (TODO) }