Synchronization in Rust
Overview
In this lecture we review how Mutex<T>s work in Rust and how they interplay with condition variables (Condvars). We understand that not everyone might be comfortable with condition variables yet (they're really weird and I (Armin) struggled to understand them when I first learned about them!), so please ask questions on Slack!
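As a warm-up, here is a minimal sketch (not part of the lecture examples; the counter is just an illustrative name) of the Mutex<T> usage we'll rely on below: lock the mutex, get back a guard, mutate through the guard, and let the lock release when the guard goes out of scope.
use std::sync::{Arc, Mutex};
use std::thread;

fn main() {
    // Shared counter: Arc lets several threads own the same allocation,
    // and Mutex ensures only one thread mutates it at a time.
    let counter = Arc::new(Mutex::new(0u32));
    let mut handles = Vec::new();
    for _ in 0..4 {
        let counter = counter.clone();
        handles.push(thread::spawn(move || {
            // lock() returns a Result: unwrap panics if another thread
            // panicked while holding the lock (a "poisoned" mutex).
            let mut guard = counter.lock().unwrap();
            *guard += 1;
            // guard is dropped here, releasing the lock
        }));
    }
    for handle in handles {
        handle.join().unwrap();
    }
    println!("final count: {}", counter.lock().unwrap());
}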
Link Explorer
The Link Explorer example showcases how multithreading can let us overlap wait times for HTTP requests (i.e. "downloading a webpage"). This example starts at the "Multithreading" Wikipedia page and attempts to identify the linked page with the longest HTML body.
- In order to do this, we need to keep track of the longest page so far as well as the URL of that page.
- Since this data must be shared between threads, we need to wrap it in an Arc<Mutex<T>>.
- This naturally begs the question "What's the T?" (haha) Well, we want our mutex to protect both the URL as well as the length of the page associated with that URL. There can only be one T, so we must define a struct (we could have gone with a pair/tuple) to house both pieces of data.
- This plays nicely with the monitor paradigm in concurrent programming, in which we group related pieces of data together and protect that data with concurrency primitives (e.g. mutexes).
- We need to clone our Arc<Mutex<Article>> before moving it into each thread (otherwise the next thread can't have access to it).
- When we call lock on our mutex, we need to unwrap the result (in case something weird happened inside a thread that was holding onto that mutex), and we obtain a MutexGuard<Article>.
- The MutexGuard<Article> is a smart pointer that lets us access and modify the members of the article while ensuring that we have exclusive access to it (this happens through the magic of the Deref and DerefMut traits).
- Once the mutex guard goes out of scope, it releases the lock, thereby allowing another thread to acquire it (there's a short sketch of this right after the full code below).
- We must limit the number of threads running at once so that we don’t overwhelm the OS (try running the code yourself and making the batch size large!).
- The full code is shown below – feel free to run it for yourself! You'll need to add the following lines to your Cargo.toml:
[dependencies]
select = "0.4.3"
error-chain = "0.12.2"
reqwest = {version = "0.10.4", features = ["blocking"]}
- The HTTP request library is called “reqwest” and that’s pretty darn cute.
extern crate reqwest;
extern crate select;
#[macro_use]
extern crate error_chain;
use std::sync::{Arc, Mutex};
use std::thread;
use select::document::Document;
use select::predicate::Name;
error_chain! {
foreign_links {
ReqError(reqwest::Error);
IoError(std::io::Error);
}
}
struct Article {
url: String,
len: usize,
}
const BATCH_SIZE: usize = 60;
// https://rust-lang-nursery.github.io/rust-cookbook/web/scraping.html
fn main() -> Result<()> {
let body = reqwest::blocking::get("https://en.wikipedia.org/wiki/Multithreading_(computer_architecture)")?
.text()?;
// Identify all linked wikipedia pages
let links = Document::from_read(body.as_bytes())?
.find(Name("a"))
.filter_map(|n| {
if let Some(link_str) = n.attr("href") {
if link_str.starts_with("/wiki/") {
Some(format!("{}/{}", "https://en.wikipedia.org",
&link_str[1..]))
} else {
None
}
} else {
None
}
}).collect::<Vec<String>>();
let longest_article = Arc::new(Mutex::new(Article {url: "".to_string(),
len: 0}));
let num_batches = links.len()/BATCH_SIZE;
println!("num_batches: {}", num_batches);
for batch_idx in 0..num_batches {
// println!("link: {}", link);
println!("batch_idx: {}", batch_idx);
let mut reqwesters = Vec::new();
let start = batch_idx * BATCH_SIZE;
let end = std::cmp::min((batch_idx + 1) * BATCH_SIZE, links.len());
for link in &links[start..end] {
let longest_article_clone = longest_article.clone();
let link_clone = link.clone();
reqwesters.push(thread::spawn(move || {
let body = reqwest::blocking::get(&link_clone).expect("Failed to download page").text().expect("Failed to read page body");
let curr_len = body.len();
let mut longest_article_ref = longest_article_clone.lock().unwrap();
if curr_len > longest_article_ref.len {
longest_article_ref.len = curr_len;
longest_article_ref.url = link_clone.to_string();
}
}));
}
for handle in reqwesters {
handle.join().expect("Panic occurred in thread!");
}
//println!("page length: {}", curr_len);
}
let longest_article_ref = longest_article.lock().unwrap();
println!("{} was the longest article with length {}", longest_article_ref.url,
longest_article_ref.len);
Ok(())
}
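One point from the list above that the full code doesn't show explicitly: because the lock is released when the MutexGuard is dropped, you can end a guard's scope early (or drop it by hand) to let other threads in sooner. A minimal sketch of both options (not from the lecture code):
use std::sync::Mutex;

fn main() {
    let data = Mutex::new(vec![1, 2, 3]);
    {
        // The guard only lives inside this block...
        let mut guard = data.lock().unwrap();
        guard.push(4);
    } // ...so the lock is released here, before any later work happens.

    // Alternatively, drop the guard explicitly to release the lock early.
    let guard = data.lock().unwrap();
    println!("len = {}", guard.len());
    drop(guard); // lock released; another thread could acquire it now
}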
SemaPlusPlus
- The batching approach above is not that flexible – ideally, we'd like to spawn threads as necessary and only hold off on spawning until the number of running threads drops back below a certain threshold.
- We can impose that limit using a semaphore as you saw in CS110. You will have to do this yourself in assignment 5 (News Aggregator).
- We can do even better: why spin up new threads and let threads die when we can reuse existing ones? You will answer that question with your very own ThreadPool implementation in CS110 assignment 6.
- Now we will practice using condition variables in Rust by implementing our very own SemaPlusPlus<T>.
- A SemaPlusPlus<T> is not just a regular old semaphore. Instead of incrementing a counter via signal, you can enqueue a message to be read by calling send.
- Instead of decrementing a counter via wait, you can read a previously sent message by calling recv (and you will wait while the message queue is empty).
- Like in C++, a Rust condition variable requires you to acquire a lock first.
- In Rust, however, cv.wait_while takes as input a MutexGuard (the unwrapped output of calling lock) and returns this MutexGuard back once it's done waiting (it must take ownership of the MutexGuard).
- It is also idiomatic to associate your Mutex with its corresponding Condvar by putting them in a pair and wrapping this pair in an Arc.
- We provide the full implementation below, and you can find a playground link here.
extern crate rand;
use std::sync::{Arc, Mutex, Condvar};
use std::{thread, time};
use std::collections::VecDeque;
use rand::Rng;
fn rand_sleep() {
let mut rng = rand::thread_rng();
thread::sleep(time::Duration::from_millis(rng.gen_range(0, 30)));
}
#[derive(Clone)]
pub struct SemaPlusPlus<T> {
queue_and_cv: Arc<(Mutex<VecDeque<T>>, Condvar)>,
}
impl<T> SemaPlusPlus<T> {
pub fn new() -> Self {
SemaPlusPlus {queue_and_cv: Arc::new((Mutex::new(VecDeque::new()),
Condvar::new()))}
}
// Enqueues -- Like semaphore.signal()
pub fn send(&self, message: T) {
let (queue_lock, cv) = &*self.queue_and_cv;
let mut queue = queue_lock.lock().unwrap();
let queue_was_empty = queue.is_empty();
queue.push_back(message);
if queue_was_empty {
cv.notify_all();
}
}
// Dequeues -- Like semaphore.wait()
pub fn recv(&self) -> T {
let (queue_lock, cv) = &*self.queue_and_cv;
// Wait until there is something to dequeue
let mut queue = cv.wait_while(queue_lock.lock().unwrap(), |queue| {
queue.is_empty()
}).unwrap();
// Should return Some(...) because we waited
queue.pop_front().unwrap()
}
}
const NUM_THREADS: usize = 12;
fn main() {
// Inspired by this example https://doc.rust-lang.org/stable/rust-by-example/std_misc/channels.html
let sem: SemaPlusPlus<String> = SemaPlusPlus::new();
let mut handles = Vec::new();
for i in 0..NUM_THREADS {
let sem_clone = sem.clone();
let handle = thread::spawn(move || {
rand_sleep();
sem_clone.send(format!("thread {} just finished!", i))
});
handles.push(handle);
}
for _ in 0..NUM_THREADS {
println!("{}", sem.recv())
}
for handle in handles {
handle.join().unwrap();
}
}
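As a usage note tying this back to the bullets at the top of this section: the same Mutex-plus-Condvar pairing can implement a plain counting semaphore that caps how many downloader threads run at once, instead of processing fixed-size batches. Below is a rough sketch of that idea; the Semaphore type, MAX_IN_FLIGHT, and the dummy worker body are all invented for illustration and are not the assignment's required design.
use std::sync::{Arc, Condvar, Mutex};
use std::thread;

// A counting semaphore built from the same Mutex + Condvar pairing.
#[derive(Clone)]
struct Semaphore {
    count_and_cv: Arc<(Mutex<usize>, Condvar)>,
}

impl Semaphore {
    fn new(permits: usize) -> Self {
        Semaphore {count_and_cv: Arc::new((Mutex::new(permits), Condvar::new()))}
    }

    // Like semaphore.wait(): block until a permit is available, then take it.
    fn acquire(&self) {
        let (count_lock, cv) = &*self.count_and_cv;
        let mut count = cv.wait_while(count_lock.lock().unwrap(), |count| {
            *count == 0
        }).unwrap();
        *count -= 1;
    }

    // Like semaphore.signal(): return a permit and wake one waiter.
    fn release(&self) {
        let (count_lock, cv) = &*self.count_and_cv;
        *count_lock.lock().unwrap() += 1;
        cv.notify_one();
    }
}

const MAX_IN_FLIGHT: usize = 8; // hypothetical cap on concurrent downloads

fn main() {
    let sem = Semaphore::new(MAX_IN_FLIGHT);
    let mut handles = Vec::new();
    for i in 0..32 {
        // Blocks once MAX_IN_FLIGHT worker threads are still running.
        sem.acquire();
        let sem_clone = sem.clone();
        handles.push(thread::spawn(move || {
            // ... fetch a page and update the shared longest-article state here ...
            println!("worker {} running", i);
            sem_clone.release(); // lets main spawn another worker
        }));
    }
    for handle in handles {
        handle.join().unwrap();
    }
}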