Stupid tricks with Rust higher-order functions and "impl trait"
While attending CodeMash 2017, I had a realization about how an upcoming Rust feature could make higher-order functions nicer without the overhead of a heap allocation. I wanted to share the idea and see what other people think.
Sometimes, you’d like to take a function with N arguments and hard-code some of the arguments while allowing the others to vary. This is known as partial application. As of Rust 1.15, you can do this by creating a closure that partially applies one or more arguments:
fn add_3_numbers(a: u8, b: u8, c: u8) -> u8 {
    a + b + c
}

fn main() {
    assert_eq!(6, add_3_numbers(1, 2, 3));
    let add_2_numbers_to_the_value_1 = |a, b| add_3_numbers(a, b, 1);
    assert_eq!(6, add_2_numbers_to_the_value_1(2, 3));
}
One really cool thing is that the compiler “sees through” the closure and optimizes it away, just the same as if you had called the function directly.
Unfortunately, we cannot abstract the closure creation to a function that creates closures that partially apply with a number other than 1:
fn make_add_2_numbers(c: u8) -> ??? {
    move |a, b| add_3_numbers(a, b, c)
}
The problem is that it’s not possible for the programmer to give a name to the type of a closure in stable Rust. The current workaround is to return a boxed trait object:
fn make_add_2_numbers(c: u8) -> Box<Fn(u8, u8) -> u8> {
    Box::new(move |a, b| add_3_numbers(a, b, c))
}

fn main() {
    let add_2_numbers_to_the_value_1 = make_add_2_numbers(1);
    assert_eq!(6, add_2_numbers_to_the_value_1(2, 3));
}
In some cases (like this one), the optimizer can see through the heap allocation and remove it entirely [1]. However, there’s no guarantee that this will occur.
In unstable Rust, there’s a feature that promises to help: conservative_impl_trait (RFC 1522). This feature allows the programmer to declare that a function will return some type that adheres to a trait, but without saying what the exact type is; the compiler will “fill in” the exact type for us. This means that we don’t need to be able to write the type of a closure!
This feature is frequently talked about when someone is looking to return an iterator or a Future, both of which frequently contain closures. However, Rust’s function traits (Fn, FnMut, and FnOnce) are fair game to participate in conservative_impl_trait on their own. We can write our example as:
#![feature(conservative_impl_trait)]

fn make_add_2_numbers(c: u8) -> impl Fn(u8, u8) -> u8 {
    move |a, b| add_3_numbers(a, b, c)
}

fn main() {
    let add_2_numbers_to_the_value_1 = make_add_2_numbers(1);
    assert_eq!(6, add_2_numbers_to_the_value_1(2, 3));
}
This is nice, as we now know that there’s no extra heap allocation.
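The same pattern should carry over to the other two function traits. Here is a minimal sketch, with hypothetical function names (make_running_total and make_shout are made up for illustration). It assumes a toolchain where impl Trait in return position is accepted without a feature gate; on the nightly compiler described in this post, the #![feature(conservative_impl_trait)] line would still be needed at the top.

```rust
// Hypothetical example: returns a closure that mutates its captured state,
// so it can only promise FnMut, not Fn.
fn make_running_total(start: u32) -> impl FnMut(u32) -> u32 {
    let mut total = start;
    move |n| {
        total += n;
        total
    }
}

// Hypothetical example: returns a closure that gives away its captured
// String, so it can only be called once (FnOnce).
fn make_shout(mut message: String) -> impl FnOnce() -> String {
    move || {
        message.push('!');
        message
    }
}

fn main() {
    // The binding must be `mut` because calling an FnMut needs `&mut self`.
    let mut add_to_total = make_running_total(10);
    assert_eq!(11, add_to_total(1));
    assert_eq!(13, add_to_total(2));

    let shout = make_shout(String::from("hi"));
    assert_eq!(String::from("hi!"), shout());
}
```

In both cases the caller only sees the trait it is allowed to rely on, and no Box is involved.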
Higher-order functions
We can go one step further and create a higher-order function, a function that accepts a function as an argument and returns another function as the result. The traditional example is to add logging:
#![feature(conservative_impl_trait)]

fn add(a: u8, b: u8) -> u8 {
    a + b
}

fn sub(a: u8, b: u8) -> u8 {
    a - b
}

fn log<F>(f: F) -> impl Fn(u8, u8) -> u8
    where F: Fn(u8, u8) -> u8,
{
    move |a, b| {
        println!("Calling with {}, {}", a, b);
        let r = f(a, b);
        println!("Result was {}", r);
        r
    }
}

fn main() {
    let logging_add = log(add);
    let logging_sub = log(sub);
    logging_add(1, 2);
    logging_sub(2, 1);
}
The log function only knows that it will be given some type that can be called with two u8 arguments and will return another u8. This is specified by the where clause, and should be familiar if you’ve done anything with closures in Rust.
Adding logging in this manner allows us to separate the concerns of adding, subtracting, and logging. We also get to reuse the logging code.
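The where clause does not have to pin everything to u8. As a sketch of how the reuse could go further, here is a variant that is generic over the argument and result types. The name log_generic and the Debug bounds are my own assumptions for illustration, and the feature gate is left out on the assumption that the toolchain accepts impl Trait in return position without it.

```rust
use std::fmt::Debug;

fn add(a: u8, b: u8) -> u8 {
    a + b
}

// Hypothetical generalization of `log`: works for any two-argument function
// whose arguments and result can be printed with {:?}.
fn log_generic<F, A, B, R>(f: F) -> impl Fn(A, B) -> R
where
    F: Fn(A, B) -> R,
    A: Debug,
    B: Debug,
    R: Debug,
{
    move |a, b| {
        // Print by reference first, then hand ownership to the wrapped function.
        println!("Calling with {:?}, {:?}", a, b);
        let r = f(a, b);
        println!("Result was {:?}", r);
        r
    }
}

fn main() {
    // Works with a plain function...
    let logging_add = log_generic(add);
    assert_eq!(3, logging_add(1, 2));

    // ...and with closures over unrelated types.
    let logging_repeat = log_generic(|s: String, n: usize| s.repeat(n));
    assert_eq!(String::from("abab"), logging_repeat(String::from("ab"), 2));
}
```

The trade-off is a noisier signature in exchange for writing the logging code once for every two-argument function.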
A practical example: parser combinators
Parser combinators allow you to build simple parsers and then combine them to build up parsers for more complicated grammars. As a simplified example, check out this parser that parses a string up to the next newline:
// On success, return what was parsed and where to start parsing from next.
// On failure, return where the parsing failed.
type Result<T> = ::std::result::Result<(T, usize), usize>;

fn parse_until_newline(s: &str, location: usize) -> Result<&str> {
    let s = &s[location..];
    match s.find("\n") {
        Some(pos) => {
            let head = &s[..pos];
            Ok((head, location + pos + "\n".len()))
        },
        None => {
            Err(location)
        }
    }
}

fn main() {
    let input = "hello\nworld";
    assert_eq!(Ok(("hello", 6)), parse_until_newline(input, 0));
    assert_eq!(Err(6), parse_until_newline(input, 6));
}
Our parser is likely to need the ability to find other hard-coded strings besides a newline. Instead of copy-and-pasting the implementation, we can transform our parse_until_newline function into the higher-order parse_until. This will create parser functions that allow parsing until any user-supplied string.
#![feature(conservative_impl_trait)]

fn parse_until(terminator: &'static str) -> impl Fn(&str, usize) -> Result<&str> {
    move |s, location| {
        let s = &s[location..];
        match s.find(terminator) {
            Some(pos) => {
                let head = &s[..pos];
                Ok((head, location + pos + terminator.len()))
            },
            None => {
                Err(location)
            }
        }
    }
}

fn main() {
    let input = "hello\nworld";
    assert_eq!(Ok(("hello", 6)), parse_until("\n")(input, 0));
    assert_eq!(Err(6), parse_until("\n")(input, 6));
}
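Because each parser returns the location to continue from, parsers built by parse_until can be chained by hand. The following is only a sketch of that idea: the key=value;rest input is invented for illustration, and the feature gate is omitted on the assumption that the toolchain accepts impl Trait in return position without it.

```rust
// Same alias as in the post: (parsed value, next location) or failure location.
type Result<T> = ::std::result::Result<(T, usize), usize>;

fn parse_until(terminator: &'static str) -> impl Fn(&str, usize) -> Result<&str> {
    move |s, location| {
        let s = &s[location..];
        match s.find(terminator) {
            Some(pos) => Ok((&s[..pos], location + pos + terminator.len())),
            None => Err(location),
        }
    }
}

fn main() {
    // Build the parsers once, then reuse them.
    let until_equals = parse_until("=");
    let until_semicolon = parse_until(";");

    let input = "key=value;rest";

    // Feed the location returned by one parser into the next one.
    let (key, next) = until_equals(input, 0).unwrap();
    let (value, next) = until_semicolon(input, next).unwrap();

    assert_eq!("key", key);
    assert_eq!("value", value);
    assert_eq!("rest", &input[next..]);
}
```

Threading the location by hand like this gets tedious quickly, which is the kind of plumbing that combinators can take over.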
These basic parsers work well until we need to be able to reuse our parser in a slightly different context. For example, we might want to be able to recover from a failed parse because the thing we are trying to match against is optional. We can implement this by creating a higher-order function that accepts any type that implements Fn(&str, usize) and returns a new Fn(&str, usize) that converts a success into a Some and a failure to a None.
fn optional<F>(parser: F) -> impl Fn(&str, usize) -> Result<Option<&str>>
    where F: Fn(&str, usize) -> Result<&str>
{
    move |s, location| {
        match parser(s, location) {
            Ok((s, l)) => Ok((Some(s), l)),
            Err(_) => Ok((None, location)),
        }
    }
}

fn main() {
    let input = "hello\nworld";
    assert_eq!(Ok((Some("hello"), 6)), optional(parse_until("\n"))(input, 0));
    assert_eq!(Ok((None, 6)), optional(parse_until("\n"))(input, 6));
}
This can be extended to many other combinators, such as zero-or-more or one-or-more. Alternation can also be implemented, but it’s a little tricky to get a good API that doesn’t rely on trait objects and isn’t ugly.
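To make the zero-or-more case concrete, here is one way it might look in this style. This is a sketch, not code from Peresil or Strata: the name zero_or_more, the Vec return type, and the guard against parsers that succeed without consuming input are all my own assumptions. The feature gate is omitted on the assumption that the toolchain accepts impl Trait in return position without it.

```rust
type Result<T> = ::std::result::Result<(T, usize), usize>;

fn parse_until(terminator: &'static str) -> impl Fn(&str, usize) -> Result<&str> {
    move |s, location| {
        let s = &s[location..];
        match s.find(terminator) {
            Some(pos) => Ok((&s[..pos], location + pos + terminator.len())),
            None => Err(location),
        }
    }
}

// Hypothetical combinator: applies `parser` repeatedly, collecting every
// success, and stops at the first failure. It never fails itself.
fn zero_or_more<F>(parser: F) -> impl Fn(&str, usize) -> Result<Vec<&str>>
where
    F: Fn(&str, usize) -> Result<&str>,
{
    move |s, mut location| {
        let mut items = Vec::new();
        while let Ok((item, next)) = parser(s, location) {
            // Assumed safeguard: a parser that succeeds without advancing
            // would otherwise loop forever.
            if next == location {
                break;
            }
            items.push(item);
            location = next;
        }
        Ok((items, location))
    }
}

fn main() {
    let fields = zero_or_more(parse_until(","));
    // "c" has no trailing comma, so parsing stops at location 4.
    assert_eq!(Ok((vec!["a", "b"], 4)), fields("a,b,c", 0));
    // No match at all is still a success, with an empty list.
    assert_eq!(Ok((Vec::<&str>::new(), 4)), fields("a,b,c", 4));
}
```

A one-or-more variant could be built along the same lines by returning Err(location) when the collected list is empty.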
I’m continuing to explore this implementation style of parser combinators in Peresil, the parsing library that underlies SXD, a Rust XML library. Since conservative_impl_trait requires nightly and SXD works on stable, I’m testing the ideas out in a new project, Strata. Among other things, Strata is an up-and-coming parser of Rust code. If you are interested in any of these topics, I’d love to hear from you!
[1] In this example, the optimizer actually removes all of the code. It sees through the heap allocation, performs all the math, sees that the assertion can never fire, and then removes all the code as it’s unreachable. Optimizers are neat!