Synchronization in Rust

Overview

In this lecture we review how Mutex<T>s work in Rust and how they interplay with condition variables (Condvars). We understand that not everyone might be comfortable with condition variables yet (they’re really weird and I (Armin) struggled to understand them when I first learned about them!) so please ask questions on Slack!

The Link Explorer example showcases how multithreading lets us overlap the wait times of HTTP requests (i.e. “downloading a webpage”). This example starts at the “Multithreading” Wikipedia page, downloads every Wikipedia page it links to, and identifies the linked page with the longest HTML body.

[dependencies]
select = "0.4.3"
error-chain = "0.12.2"
reqwest = {version = "0.10.4", features = ["blocking"]}
extern crate reqwest;
extern crate select;
#[macro_use]
extern crate error_chain;

use std::sync::{Arc, Mutex};
use std::{thread};
use select::document::Document;
use select::predicate::Name;

// Declares this program's error types via the `error_chain!` macro (this is
// where the `Result` alias returned by `main` comes from). The `foreign_links`
// entries wrap `reqwest::Error` and `std::io::Error` so that `?` in `main` can
// convert those errors automatically.
error_chain! {
   foreign_links {
       ReqError(reqwest::Error);
       IoError(std::io::Error);
   }
}

/// The longest article found so far; shared between downloader threads
/// behind an `Arc<Mutex<_>>`.
struct Article {
    /// Full URL of the article.
    url: String,
    /// Length of the downloaded body in bytes (`String::len`).
    len: usize,
}

/// Number of links downloaded concurrently per batch (one thread per link).
const BATCH_SIZE: usize = 60;


// https://rust-lang-nursery.github.io/rust-cookbook/web/scraping.html
fn main() -> Result<()> {
    let body = reqwest::blocking::get("https://en.wikipedia.org/wiki/Multithreading_(computer_architecture)")?
    .text()?;
    // Identify all linked wikipedia pages
    let links = Document::from_read(body.as_bytes())?
        .find(Name("a"))
        .filter_map(|n| {
            if let Some(link_str) = n.attr("href") {
                if link_str.starts_with("/wiki/") {
                    Some(format!("{}/{}", "https://en.wikipedia.org",
                        &link_str[1..]))
                } else {
                    None
                }
            } else {
                None
            }
        }).collect::<Vec<String>>();
    let longest_article = Arc::new(Mutex::new(Article {url: "".to_string(),
        len: 0}));
    let num_batches = links.len()/BATCH_SIZE;
    println!("num_batches: {}", num_batches);
    for batch_idx in 0..num_batches {
        // println!("link: {}", link);
        println!("batch_idx: {}", batch_idx);
        let mut reqwesters = Vec::new();
        let start = batch_idx * BATCH_SIZE;
        let end = std::cmp::min((batch_idx + 1) * BATCH_SIZE, links.len());
        for link in &links[start..end] {
            let longest_article_clone = longest_article.clone();
            let link_clone = link.clone();
            reqwesters.push(thread::spawn(move || {
                let body = reqwest::blocking::get(&link_clone).expect("").text().expect("");
                let curr_len = body.len();
                let mut longest_article_ref = longest_article_clone.lock().unwrap();
                if curr_len > longest_article_ref.len {
                    longest_article_ref.len = curr_len;
                    longest_article_ref.url = link_clone.to_string();
                }
            }));
        }

        for handle in reqwesters {
            handle.join().expect("Panic occurred in thread!");
        }
        //println!("page length: {}", curr_len);
    }



    let longest_article_ref = longest_article.lock().unwrap();
    println!("{} was the longest article with length {}", longest_article_ref.url,
        longest_article_ref.len);
    Ok(())
}

SemaPlusPlus

extern crate rand;
use std::sync::{Arc, Mutex, Condvar};

use std::{thread, time};
use std::collections::VecDeque;
use rand::Rng;

/// Puts the current thread to sleep for a random duration drawn from
/// `gen_range(0, 30)` milliseconds.
fn rand_sleep() {
    let millis = rand::thread_rng().gen_range(0, 30);
    let pause = time::Duration::from_millis(millis);
    thread::sleep(pause);
}

/// A thread-safe, blocking FIFO queue: a semaphore that also carries a
/// message with every signal.
///
/// Cloning a `SemaPlusPlus` yields another handle to the *same* underlying
/// queue, so clones can be moved into other threads to send or receive.
pub struct SemaPlusPlus<T> {
    /// The queue and the condition variable used to wait for it to become
    /// non-empty, shared by all handles.
    queue_and_cv: Arc<(Mutex<VecDeque<T>>, Condvar)>,
}

// Implemented by hand rather than derived: `#[derive(Clone)]` would add a
// `T: Clone` bound, even though cloning a handle only copies the `Arc` and
// never clones a message.
impl<T> Clone for SemaPlusPlus<T> {
    fn clone(&self) -> Self {
        SemaPlusPlus {
            queue_and_cv: Arc::clone(&self.queue_and_cv),
        }
    }
}

impl<T> Default for SemaPlusPlus<T> {
    /// Equivalent to [`SemaPlusPlus::new`].
    fn default() -> Self {
        Self::new()
    }
}

impl<T> SemaPlusPlus<T> {
    /// Creates a handle to a new, empty queue.
    pub fn new() -> Self {
        SemaPlusPlus {
            queue_and_cv: Arc::new((Mutex::new(VecDeque::new()), Condvar::new())),
        }
    }

    /// Enqueues `message` -- like `semaphore.signal()`.
    ///
    /// # Panics
    /// Panics if the queue's mutex is poisoned.
    pub fn send(&self, message: T) {
        let (queue_lock, cv) = &*self.queue_and_cv;
        let mut queue = queue_lock.lock().unwrap();
        let queue_was_empty = queue.is_empty();
        queue.push_back(message);
        // Receivers only block while the queue is empty, so a wake-up is only
        // needed on the empty -> non-empty transition. It must be
        // `notify_all`: later sends that find the queue non-empty do not
        // notify, so every blocked receiver has to be woken to re-check.
        if queue_was_empty {
            cv.notify_all();
        }
    }

    /// Dequeues the oldest message -- like `semaphore.wait()`.
    ///
    /// Blocks the calling thread until a message is available.
    ///
    /// # Panics
    /// Panics if the queue's mutex is poisoned.
    pub fn recv(&self) -> T {
        let (queue_lock, cv) = &*self.queue_and_cv;
        // Wait until there is something to dequeue. `wait_while` re-checks
        // the predicate after every wake-up, so spurious wake-ups and racing
        // receivers are handled.
        let mut queue = cv
            .wait_while(queue_lock.lock().unwrap(), |queue| queue.is_empty())
            .unwrap();
        // Returns Some(...) because we hold the lock and the wait ended only
        // once the queue was non-empty.
        queue.pop_front().unwrap()
    }
}

/// Number of sender threads spawned by `main`, and therefore the number of
/// messages `main` receives.
const NUM_THREADS: usize = 12;
/// Spawns `NUM_THREADS` workers that each sleep for a random time and then
/// send a completion message; prints the messages as they arrive, then joins
/// every worker.
fn main() {
    // Inspired by this example https://doc.rust-lang.org/stable/rust-by-example/std_misc/channels.html
    let sem: SemaPlusPlus<String> = SemaPlusPlus::new();

    // `collect` drives the iterator, so every worker is spawned right here.
    let workers: Vec<_> = (0..NUM_THREADS)
        .map(|worker_id| {
            let sender = sem.clone();
            thread::spawn(move || {
                rand_sleep();
                sender.send(format!("thread {} just finished!", worker_id))
            })
        })
        .collect();

    // One message is expected per worker; `recv` blocks until one arrives.
    (0..NUM_THREADS).for_each(|_| println!("{}", sem.recv()));

    for worker in workers {
        worker.join().unwrap();
    }
}