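//! treeadvisor/import_osm/src/main.rs
//!
//! Imports OpenStreetMap PBF files into PostgreSQL in two passes over every
//! file named on the command line: phase 1 collects node coordinates, phase 2
//! inserts each way (resolved coordinates plus tags) into `osm_ways`.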

use std::collections::HashMap;
use std::env::args;
use std::fs::File;
use std::io::{Seek, SeekFrom, BufReader};
use std::time::Instant;
use std::sync::mpsc::{sync_channel, SyncSender, Receiver};
use std::sync::Arc;
use std::thread;
use osm_pbf_iter::*;
/// Phase-1 worker: collects the coordinates of every node in the blobs it receives.
///
/// Consumes blobs from `req_rx` until the sending side is closed, then sends a
/// single map of node id -> `(lon, lat)` on `res_tx`. Ways and relations are ignored.
///
/// # Panics
/// Panics if the result cannot be delivered because `res_tx`'s receiver is gone.
fn phase1_worker(req_rx: Receiver<Blob>, res_tx: SyncSender<HashMap<i64, (f64, f64)>>) {
    let mut coords_by_id = HashMap::new();
    // `iter()` ends once every sender has been dropped, i.e. the file is fully dispatched.
    for blob in req_rx.iter() {
        let bytes = blob.into_data();
        let block = PrimitiveBlock::parse(&bytes);
        for primitive in block.primitives() {
            if let Primitive::Node(node) = primitive {
                coords_by_id.insert(node.id as i64, (node.lon, node.lat));
            }
        }
    }
    res_tx.send(coords_by_id).unwrap();
}
/// Phase-2 worker: inserts every way found in the received blobs into `osm_ways`.
///
/// Opens its own database connection, then consumes blobs from `req_rx` until the
/// sending side is closed. Each blob is written inside one transaction. For every way,
/// its node references are resolved through `node_coords` (the coordinate pairs built
/// in phase 1); references missing from the map are dropped. Its tags are stored as a
/// JSON object of string values. When done, a unit value is sent on `res_tx`.
///
/// The connection string comes from the `DATABASE_URL` environment variable when set,
/// otherwise from the built-in default below.
///
/// # Panics
/// Panics if the database connection, statement preparation, an insert or a commit
/// fails, or if `res_tx`'s receiver is gone.
fn phase2_worker(
    req_rx: Receiver<Blob>,
    res_tx: SyncSender<()>,
    node_coords: Arc<HashMap<i64, (f64, f64)>>,
) {
    // Fallback when DATABASE_URL is unset.
    // TODO(review): credentials should not live in source control.
    const DEFAULT_DB_URL: &str =
        "host=10.233.1.2 dbname=treeadvisor user=treeadvisor password=123";
    let db_url =
        std::env::var("DATABASE_URL").unwrap_or_else(|_| DEFAULT_DB_URL.to_string());
    let mut db = postgres::Client::connect(&db_url, postgres::NoTls).expect("DB");
    // Prepare once per connection instead of re-parsing the SQL for every row.
    let insert_way = db
        .prepare("INSERT INTO osm_ways (geo, id, attrs) VALUES ($1, $2, $3)")
        .expect("prepare INSERT INTO osm_ways");

    for blob in req_rx.iter() {
        let data = blob.into_data();
        let primitive_block = PrimitiveBlock::parse(&data);
        let mut tx = db.transaction().expect("begin transaction");
        for primitive in primitive_block.primitives() {
            let way = match primitive {
                Primitive::Way(way) => way,
                Primitive::Node(_) | Primitive::Relation(_) => continue,
            };
            let points: Vec<(f64, f64)> = way
                .refs()
                .filter_map(|id| node_coords.get(&id).copied())
                .collect();
            if points.is_empty() {
                // None of this way's nodes were seen in phase 1. An empty LineString is
                // presumably sent as a zero-point PATH, which PostgreSQL rejects; the
                // failing insert would abort this worker. Skip the way instead.
                continue;
            }
            let tags: serde_json::Map<String, serde_json::Value> = way
                .tags()
                .map(|(k, v)| (k.to_string(), serde_json::Value::String(v.to_string())))
                .collect();
            tx.execute(
                &insert_way,
                &[
                    &geo::LineString::from(points),
                    &(way.id as i64),
                    &serde_json::Value::Object(tags),
                ],
            )
            .expect("INSERT INTO osm_ways");
        }
        tx.commit().expect("commit transaction");
    }
    res_tx.send(()).expect("main thread stopped waiting for phase2 results");
}
fn main() {
let cpus = num_cpus::get();
let mut node_coords: HashMap<i64, (f64, f64)> = HashMap::new();
for arg in args().skip(1) {
let mut phase1_workers = Vec::with_capacity(cpus);
for _ in 0..cpus {
let (req_tx, req_rx) = sync_channel(2);
let (res_tx, res_rx) = sync_channel(0);
phase1_workers.push((req_tx, res_rx));
thread::spawn(move || {
phase1_worker(req_rx, res_tx);
});
}
println!("Phase1: open {}", arg);
let f = File::open(&arg).unwrap();
let mut reader = BlobReader::new(BufReader::new(f));
let start = Instant::now();
let mut w = 0;
for blob in &mut reader {
let req_tx = &phase1_workers[w].0;
w = (w + 1) % cpus;
req_tx.send(blob).unwrap();
}
for (req_tx, res_rx) in phase1_workers.into_iter() {
drop(req_tx);
let mut worker_res = res_rx.recv().unwrap();
if node_coords.is_empty() {
node_coords = worker_res;
} else {
// merge
for (id, coords) in worker_res.drain() {
node_coords.insert(id, coords);
}
}
}
let stop = Instant::now();
let duration = stop.duration_since(start);
let duration = duration.as_secs() as f64 + (duration.subsec_nanos() as f64 / 1e9);
let mut f = reader.into_inner();
match f.seek(SeekFrom::Current(0)) {
Ok(pos) => {
let rate = pos as f64 / 1024f64 / 1024f64 / duration;
println!("Phase1: Processed {} MB in {:.2} seconds ({:.2} MB/s)",
pos / 1024 / 1024, duration, rate);
},
Err(_) => (),
}
}
println!("{} nodes", node_coords.len());
let node_coords = Arc::new(node_coords);
for arg in args().skip(1) {
let mut phase2_workers = Vec::with_capacity(cpus);
for _ in 0..cpus {
let (req_tx, req_rx) = sync_channel(2);
let (res_tx, res_rx) = sync_channel(0);
phase2_workers.push((req_tx, res_rx));
let node_coords = node_coords.clone();
thread::spawn(move || {
phase2_worker(req_rx, res_tx, node_coords);
});
}
println!("Phase2: open {}", arg);
let f = File::open(&arg).unwrap();
let mut reader = BlobReader::new(BufReader::new(f));
let start = Instant::now();
let mut w = 0;
for blob in &mut reader {
let req_tx = &phase2_workers[w].0;
w = (w + 1) % cpus;
req_tx.send(blob).unwrap();
}
for (req_tx, res_rx) in phase2_workers.into_iter() {
drop(req_tx);
let _worker_res = res_rx.recv().unwrap();
}
let stop = Instant::now();
let duration = stop.duration_since(start);
let duration = duration.as_secs() as f64 + (duration.subsec_nanos() as f64 / 1e9);
let mut f = reader.into_inner();
match f.seek(SeekFrom::Current(0)) {
Ok(pos) => {
let rate = pos as f64 / 1024f64 / 1024f64 / duration;
println!("Phase2: Processed {} MB in {:.2} seconds ({:.2} MB/s)",
pos / 1024 / 1024, duration, rate);
},
Err(_) => (),
}
}
}