Leveraging Rust async for inversion of program control flow

Intro

While working on some of my projects, I came upon an interesting problem:

How to connect a push-based and a pull-based stream component?

As an example, imagine a lexer and a parser. The parser is designed to pull tokens from the lexer sequentially and process them.

// parser
let mut lexer = SomeLexer {};

let token1 = lexer.next();
// do something with the token
let token2 = lexer.next();
// etc

Now if the lexer supports that style of operation natively, everything is good.

However, if the lexer expects to push the data it produces somewhere else, we have a push-pull mismatch.

// lexer
self.push(Token::OpenBrace);
self.push(Token::CloseBrace);

To convert the lexer to support the iterator-style interface we would need to save its local state in some data structure in such a way that we can pause execution, switch control to the consumer and resume the execution of the lexer.

This concept is known as a coroutine, and many programming languages offer support for it. An example would be Python with its generator functions. While Rust does not have proper coroutine support at the time of writing, it does have asynchronous functions. These work by storing the state required to resume execution from a specific point in the code (an await point).
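For comparison, here is a minimal sketch of what saving the lexer state by hand could look like. The Token variants are taken from the lexer snippet above; BraceLexer and its fields are hypothetical names chosen only for this illustration.

```rust
/// Token type from the lexer snippet above.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum Token {
    OpenBrace,
    CloseBrace,
}

/// Hypothetical hand-written pull lexer: everything needed to resume
/// has to live in this struct instead of in local variables.
pub struct BraceLexer<'a> {
    input: &'a [u8],
    pos: usize, // the saved "resume point"
}

impl<'a> BraceLexer<'a> {
    pub fn new(input: &'a str) -> Self {
        Self { input: input.as_bytes(), pos: 0 }
    }
}

impl Iterator for BraceLexer<'_> {
    type Item = Token;

    fn next(&mut self) -> Option<Token> {
        // Resume from the saved position and stop at the next token.
        while let Some(&byte) = self.input.get(self.pos) {
            self.pos += 1;
            match byte {
                b'{' => return Some(Token::OpenBrace),
                b'}' => return Some(Token::CloseBrace),
                _ => {} // skip anything that is not a brace
            }
        }
        None
    }
}
```

For a lexer this small the manual version is manageable, but every additional piece of lexer state has to be moved into the struct by hand. This is the bookkeeping that async functions let the compiler do for us.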

Basic Setup

So how do we adapt our lexer to that style?

First we introduce a one-item buffer to enable a handover between both pieces of code.

use core::cell::Cell;

pub struct ImpedanceMatcher<T> {
    data: Cell<Option<T>>,
}

Then we add some operations:

impl<T> ImpedanceMatcher<T> {
    pub fn new() -> Self {
        Self {
            data: Cell::new(None),
        }
    }
    fn next(&self) -> Option<T> {
        self.data.replace(None)
    }
}

Please note that because we use Cell here, the entire thing becomes !Sync (shared references to it cannot be used from other threads). As a result of this trade-off we gain the ability to modify the content of the data field through shared references, which we need in order to pass the matcher to both the producer and the consumer.
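To illustrate the point about shared references, here is a small sketch that writes to a Cell through one shared reference and takes the value out through another. The function name hand_over is hypothetical and not part of the matcher.

```rust
use core::cell::Cell;

/// Stores a value through one shared reference and takes it out through another.
fn hand_over(value: u32) -> (Option<u32>, Option<u32>) {
    let slot: Cell<Option<u32>> = Cell::new(None);

    // Two shared references to the same Cell, as the producer and the
    // consumer would each hold one.
    let producer = &slot;
    let consumer = &slot;

    producer.set(Some(value)); // write without needing `&mut`
    let first = consumer.replace(None); // take the value, leaving None behind
    let second = consumer.replace(None); // the slot is empty now

    (first, second)
}
```

Calling hand_over(7) yields (Some(7), None): the value crosses from one reference to the other exactly once.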

With this at hand, we can now trivially implement an IteratorAdapter for this struct by wrapping a shared reference:

pub struct IteratorAdapter<'a, T> {
    matcher: &'a ImpedanceMatcher<T>,
}
impl<T> Iterator for IteratorAdapter<'_, T> {
    type Item = T;
    fn next(&mut self) -> Option<Self::Item> {
        self.matcher.next()
    }
}

Let’s try out what we have so far (in a test function):

#[cfg(test)]
#[test]
fn run_iterator() {
    let matcher = ImpedanceMatcher::<u32>::new();
    let mut iter = IteratorAdapter { matcher: &matcher };
    assert_eq!(None, iter.next());
}

When executing the code above the newly constructed iterator is empty because the buffer in the Cell is empty, and we never push data to it.

The Producer

Now we introduce our data production function:

async fn generate_outputs(matcher: &ImpedanceMatcher<u32>) {
    println!("gen_outputs");
    matcher.push(0).await;
    println!("after pushing 0");
    matcher.push(1).await;
    println!("after pushing 1");
    matcher.push(2).await;
}

To support this, we need to add a push-function to the matcher struct as follows:

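impl<T> ImpedanceMatcher<T> {
    pub fn push(&self, item: T) -> PushFuture {
        // Hand the item over through the one-item buffer.
        self.data.set(Some(item));
        PushFuture { polled: false }
    }
}

// Public because it appears in the signature of the public push method.
pub struct PushFuture {
    polled: bool,
}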

Now let us implement the Future trait for our PushFuture type:

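use core::future::Future;
use core::pin::Pin;
use core::task::{Context, Poll};

impl Future for PushFuture {
    type Output = ();

    fn poll(self: Pin<&mut Self>, _ctx: &mut Context<'_>) -> Poll<Self::Output> {
        // Nothing asynchronous happens here, so the future is always complete.
        Poll::Ready(())
    }
}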

Futures in Rust do not do anything until they are polled. Awaiting a future inside an async function is translated by the compiler into calls to the poll method of the Future trait, made whenever the enclosing future is itself polled. So, to have a type be a future we need to implement that method.

In our case, we are not planning to perform any truly asynchronous operations but only want to use the async support to suspend and resume execution. As such, for our first attempt the future is just always completed.
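The laziness of futures can be observed directly with a small sketch that polls an async block by hand. NoopWake and observe_laziness are hypothetical helper names. The waker is built with the standard Wake trait and an Arc purely to keep this example in safe code, which means it needs an allocator.

```rust
use std::cell::Cell;
use std::future::Future;
use std::pin::pin;
use std::sync::Arc;
use std::task::{Context, Wake, Waker};

/// A waker that ignores wake-ups; enough for polling by hand.
struct NoopWake;

impl Wake for NoopWake {
    fn wake(self: Arc<Self>) {}
}

/// Returns the counter before the first poll, the counter after it,
/// and whether the future completed during that poll.
fn observe_laziness() -> (u32, u32, bool) {
    let counter = Cell::new(0u32);

    // Creating the future runs none of its body.
    let future = async {
        counter.set(counter.get() + 1);
    };
    let before = counter.get();

    let waker = Waker::from(Arc::new(NoopWake));
    let mut context = Context::from_waker(&waker);
    let mut future = pin!(future);

    // The body only runs once the future is polled.
    let finished = future.as_mut().poll(&mut context).is_ready();

    (before, counter.get(), finished)
}
```

The counter is still 0 after the future has been created and only becomes 1 during the first poll, which also completes the future because its body contains no await point.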

The world’s most basic async executor

Let’s also add a run method to our ImpedanceMatcher. This method takes the matcher by reference, a future that will generate our outputs, and a closure that will consume the generated iterator.

impl<T> ImpedanceMatcher<T> {
    pub fn run<TFut, TRes>(
        &self,
        generator: TFut,
        consumer: impl FnOnce(IteratorAdapter<'_, T>) -> TRes,
    ) -> TRes {
        let iter = IteratorAdapter {
            matcher: self,
        };
        consumer(iter)
    }
}

Now we can write a test function that both produces and consumes some proper data:

#[cfg(test)]
#[test]
fn run_example() {
    async fn generate_outputs(matcher: &ImpedanceMatcher<u32>) {
        println!("gen_outputs");
        matcher.push(0).await;
        println!("after pushing 0");
        matcher.push(1).await;
        println!("after pushing 1");
        matcher.push(2).await;
    }

    let matcher = ImpedanceMatcher::new();
    let res = matcher.run(generate_outputs(&matcher), |iter| {
        let mut res = Vec::new();
        for item in iter {
            println!("received {item}");
            res.push(item);
        }
        res
    });

    assert_eq!([0, 1, 2].as_ref(), &res);
}

If we try to execute it though, it will fail:

thread 'run_example' panicked at lexer_impedance_matcher/src/lib.rs:120:5:
assertion `left == right` failed
  left: [0, 1, 2]
 right: []
note: run with `RUST_BACKTRACE=1` environment variable to display a backtrace
test run_example ... FAILED

We also note that the test does not produce any output. This is expected because, as already mentioned, Rust futures are not executed until they are polled for the first time, and the matcher buffer is initially empty.

So let’s change the adapter to poll the future to generate the next item before taking the value out of the buffer. To do this, we add a pinned reference to the future to the adapter, along with a context. The context is required by the Future trait so that a future can register interest in being polled again once some asynchronous event completes. As the execution of our future is completely controlled by the iterator, we can safely ignore it.

use core::{future::Future, pin::{pin, Pin}, task::Context};

// note the additional type parameter for the future type
pub struct IteratorAdapter<'a, T, F> {
    matcher: &'a ImpedanceMatcher<T>,
    future: Pin<&'a mut F>, // added
    context: Context<'a>, // added
}
impl<T, F> Iterator for IteratorAdapter<'_, T, F>
where
    F: Future,
{
    type Item = T;

    fn next(&mut self) -> Option<Self::Item> {
        let pinned = pin!(&mut self.future); // added
        let _ = pinned.poll(&mut self.context); // added
        self.matcher.next()
    }
}

Since a future can only be polled through a pinned reference, we store the future reference wrapped in the Pin struct and pin a mutable reference to that field again for every call to poll.

Rust offers two ways to pin a value:

  • On the heap using Box::pin
  • On the stack using the pin! macro.

In order to avoid the allocation and be able to run in a no_std environment we opt for stack pinning.
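As a rough illustration of the two options, the following sketch polls the same trivial future once from a heap pin and once from a stack pin. The names answer, NoopWake and poll_both_ways are hypothetical. The Arc-based waker is only there to keep the example in safe code, so this sketch as a whole is not no_std-compatible even though the pin! part is.

```rust
use std::future::Future;
use std::pin::pin;
use std::sync::Arc;
use std::task::{Context, Poll, Wake, Waker};

/// A waker that ignores wake-ups; enough for polling by hand.
struct NoopWake;

impl Wake for NoopWake {
    fn wake(self: Arc<Self>) {}
}

/// Hypothetical future used only for this demonstration.
async fn answer() -> u32 {
    42
}

fn poll_both_ways() -> (Poll<u32>, Poll<u32>) {
    let waker = Waker::from(Arc::new(NoopWake));
    let mut cx = Context::from_waker(&waker);

    // Heap pinning: the future lives in a Box, which needs an allocator.
    let mut boxed = Box::pin(answer());
    let from_heap = boxed.as_mut().poll(&mut cx);

    // Stack pinning: the future stays in this stack frame, no allocation.
    let mut on_stack = pin!(answer());
    let from_stack = on_stack.as_mut().poll(&mut cx);

    (from_heap, from_stack)
}
```

Both variants behave the same when polled. The difference is where the future is stored and therefore how long the pinned reference can live: a stack pin cannot outlive the function that created it.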

Of course, we need to adapt our run method to take into account these new fields when constructing the iterator:

impl<T> ImpedanceMatcher<T> {
    pub fn run<TFut, TRes>(
        &self,
        generator: TFut,
        consumer: impl FnOnce(IteratorAdapter<'_, T, TFut>) -> TRes,
    ) -> TRes {
        // infra
        let waker = build_dummy_waker(); // added
        let context = Context::from_waker(&waker); // added

        let future = pin!(generator); // added
        let iter = IteratorAdapter {
            matcher: self,
            context, // added
            future, // added
        };
        consumer(iter)
    }
}

What is this Waker we need to construct the context, even if we never end up using it?

It is a wrapper around a table of function pointers that the async runtime library can use to track when exactly polling a certain task will allow it to make progress.

Since we never use it, we just create a function table whose clone and wake entries panic if they are ever called, and whose drop entry does nothing:

use core::task::{RawWaker, RawWakerVTable, Waker};

fn build_dummy_waker() -> Waker {
    const VTABLE: RawWakerVTable = RawWakerVTable::new(
        |_ptr| unreachable!(), // clone function pointer
        |_ptr| unreachable!(), // wake function pointer
        |_ptr| unreachable!(), // wake_by_ref function pointer
        |_ptr| {}, // drop function pointer
    );
    // SAFETY: none of the vtable functions ever dereferences the null data pointer.
    unsafe {
        let raw = RawWaker::new(core::ptr::null(), &VTABLE);
        Waker::from_raw(raw)
    }
}
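To give an idea of what a waker is for when it is actually used, here is a sketch of a waker that merely counts how often it is asked to wake its task. CountingWake and count_wakes are hypothetical names. The sketch relies on the standard Wake trait and an Arc instead of a hand-written function table, so unlike the dummy waker above it needs an allocator.

```rust
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use std::task::{Wake, Waker};

/// Hypothetical waker that records how often it was asked to wake its task.
struct CountingWake {
    wakes: AtomicUsize,
}

impl Wake for CountingWake {
    fn wake(self: Arc<Self>) {
        self.wakes.fetch_add(1, Ordering::SeqCst);
    }
}

fn count_wakes() -> usize {
    let state = Arc::new(CountingWake {
        wakes: AtomicUsize::new(0),
    });
    let waker = Waker::from(Arc::clone(&state));

    // A future that cannot finish yet would keep a clone of the waker and
    // call it once polling again is worthwhile. Here we call it directly.
    waker.wake_by_ref();
    waker.clone().wake();

    state.wakes.load(Ordering::SeqCst)
}
```

A real runtime would react to such a call by scheduling the task to be polled again. Our iterator polls unconditionally on every call to next, which is why the dummy waker never needs to do anything.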

At this point, we actually have a kind of runtime that will take a future and repeatedly poll it until completion, even if it does not do anything smart.

However, our test will fail like this:

running 1 test
gen_outputs
after pushing 0
after pushing 1
received 2
thread 'run_example' panicked at lexer_impedance_matcher/src/lib.rs:101:64:
`async fn` resumed after completion
note: run with `RUST_BACKTRACE=1` environment variable to display a backtrace
test run_example ... FAILED

By looking at the output we can see that the code the Rust compiler generates for the generate_outputs function pushes all available outputs into the buffer one after the other without ever yielding control to the calling code.
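The same effect can be reproduced in isolation. The sketch below awaits a future that is always ready (std::future::ready) inside a loop and shows that a single poll runs the whole loop. NoopWake and single_poll_log are hypothetical names, and the log stands in for the matcher buffer.

```rust
use std::cell::RefCell;
use std::future::{ready, Future};
use std::pin::pin;
use std::sync::Arc;
use std::task::{Context, Wake, Waker};

/// A waker that ignores wake-ups; enough for polling by hand.
struct NoopWake;

impl Wake for NoopWake {
    fn wake(self: Arc<Self>) {}
}

/// Polls the producer exactly once and reports what it did in that poll.
fn single_poll_log() -> (Vec<u32>, bool) {
    let log = RefCell::new(Vec::new());

    let producer = async {
        for i in 0..3u32 {
            log.borrow_mut().push(i);
            ready(()).await; // always ready, so execution never suspends here
        }
    };

    let waker = Waker::from(Arc::new(NoopWake));
    let mut cx = Context::from_waker(&waker);
    let mut producer = pin!(producer);

    let finished = producer.as_mut().poll(&mut cx).is_ready();
    let entries = log.borrow().clone();

    (entries, finished)
}
```

After one poll the log already contains all three entries and the future is complete. An await point only hands control back to the caller if the awaited future returns Pending.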

To fix this, we need to force our PushFuture to yield control exactly once: it reports Pending the first time it is polled and resolves the next time. We accomplish this with the boolean field in the struct, which is set the first time the poll method is executed:

impl Future for PushFuture {
    type Output = ();

    fn poll(
        self: core::pin::Pin<&mut Self>,
        _cx: &mut core::task::Context<'_>,
    ) -> core::task::Poll<Self::Output> {
        if self.polled {
            core::task::Poll::Ready(())
        } else {
            self.get_mut().polled = true;
            core::task::Poll::Pending
        }
    }
}

This causes the async function to yield control after producing exactly one value, which the calling code can then take out of the buffer and consume.

running 1 test
gen_outputs
received 0
after pushing 0
received 1
after pushing 1
received 2
test run_example ... ok
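The yield-once behaviour can also be checked on its own by polling such a future twice by hand. In this sketch YieldOnce has the same shape as PushFuture, while NoopWake and poll_twice are hypothetical helper names.

```rust
use std::future::Future;
use std::pin::{pin, Pin};
use std::sync::Arc;
use std::task::{Context, Poll, Wake, Waker};

/// A waker that ignores wake-ups; enough for polling by hand.
struct NoopWake;

impl Wake for NoopWake {
    fn wake(self: Arc<Self>) {}
}

/// Same shape as PushFuture: pending on the first poll, ready afterwards.
struct YieldOnce {
    polled: bool,
}

impl Future for YieldOnce {
    type Output = ();

    fn poll(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<()> {
        if self.polled {
            Poll::Ready(())
        } else {
            // Remember the first poll and hand control back to the caller.
            self.polled = true;
            Poll::Pending
        }
    }
}

fn poll_twice() -> (Poll<()>, Poll<()>) {
    let waker = Waker::from(Arc::new(NoopWake));
    let mut cx = Context::from_waker(&waker);
    let mut future = pin!(YieldOnce { polled: false });

    let first = future.as_mut().poll(&mut cx);
    let second = future.as_mut().poll(&mut cx);

    (first, second)
}
```

Note that this future returns Pending without ever calling the waker. That is fine here because the iterator polls again on every call to next, but a general-purpose executor would have no reason to poll such a future a second time.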

The full source is available here: https://gist.github.com/fegies/5ac7b0171b161d77f877941c7353f0b3.