Program Synthesis is Possible in Rust
Program synthesis is the act of automatically constructing a program that fulfills a given specification. Perhaps you are interested in sketching a program, leaving parts of it incomplete, and then having a tool fill in those missing bits for you? Or perhaps you are a compiler, and you have some instruction sequence, but you want to find an equivalent-but-better instruction sequence? Program synthesizers promise to help you out in these situations!
I recently stumbled across Adrian Sampson’s Program Synthesis is Possible blog post. Adrian describes and implements minisynth, a toy program synthesizer that generates constants for holes in a template program when given a specification. What fun! As a way to learn more about program synthesis myself, I ported minisynth to Rust.
The Language
The input language is quite simple. The only type is the 64-bit signed integer, and our operations are addition, subtraction, multiplication, division, negation, left- and right-shift, and if-then-else conditionals.
Here is an example:
x * 10 + y
And here is a conditional expression that evaluates to 27 if x is non-zero, and 42 otherwise:
x ? 27 : 42
Abstract Syntax Tree
My representation of the AST uses an id-based arena and interns
identifier strings, which is a bit overkill for such a small program, but is a
pattern that has worked well for me in Rust. This pattern makes implementing
the petgraph
crate’s traits easy, which gets you all the graph traversals, dominator algorithms, and other machinery that a non-toy implementation will eventually want.
The ast::Context
structure contains the arena of AST nodes and the interned
strings.
// src/ast.rs
use id_arena::{Arena, Id};
use std::collections::HashMap;
pub type StringId = Id<String>;
#[derive(Default)]
pub struct Context {
idents: Arena<String>,
already_interned: HashMap<String, StringId>,
nodes: Arena<Node>,
}
The ast::Node
definition is an enum
with a variant for each type of
expression in the language.
// src/ast.rs
pub type NodeId = Id<Node>;
pub enum Node {
Identifier(StringId),
Addition(NodeId, NodeId),
Subtraction(NodeId, NodeId),
Multiplication(NodeId, NodeId),
Division(NodeId, NodeId),
RightShift(NodeId, NodeId),
LeftShift(NodeId, NodeId),
Const(i64),
Negation(NodeId),
Conditional(NodeId, NodeId, NodeId),
}
The ast::Context also has methods to allocate new ast::Nodes, get interned strings, and access allocated nodes. These definitions are straightforward, so I have elided them here. If you’re interested, you can look at the source on GitHub.
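To make the arena-and-interning pattern concrete without leaving this post, here is a hedged, std-only sketch: plain Vec indices stand in for id_arena's typed ids, only a few Node variants are shown, and the real methods in the linked source differ in detail.

```rust
use std::collections::HashMap;

// Plain Vec indices stand in for id_arena's typed `Id<String>` / `Id<Node>`.
type StringId = usize;
type NodeId = usize;

// Only a few variants shown; the real enum has one per expression type.
enum Node {
    Identifier(StringId),
    Const(i64),
    Addition(NodeId, NodeId),
}

#[derive(Default)]
struct Context {
    idents: Vec<String>,
    already_interned: HashMap<String, StringId>,
    nodes: Vec<Node>,
}

impl Context {
    // Allocate a node in the arena and return its id.
    fn new_node(&mut self, node: Node) -> NodeId {
        self.nodes.push(node);
        self.nodes.len() - 1
    }

    // Intern the identifier string, reusing its id if we've seen it before,
    // and wrap it in an `Identifier` node.
    fn new_identifier(&mut self, s: &str) -> NodeId {
        let id = match self.already_interned.get(s) {
            Some(&id) => id,
            None => {
                self.idents.push(s.to_string());
                let id = self.idents.len() - 1;
                self.already_interned.insert(s.to_string(), id);
                id
            }
        };
        self.new_node(Node::Identifier(id))
    }

    // Resolve an interned string id back to the string itself.
    fn interned(&self, id: StringId) -> &str {
        &self.idents[id]
    }

    // Get a reference to an allocated node.
    fn node_ref(&self, id: NodeId) -> &Node {
        &self.nodes[id]
    }
}

fn main() {
    let mut ctx = Context::default();
    let x1 = ctx.new_identifier("x");
    let x2 = ctx.new_identifier("x");
    // Two distinct identifier nodes, but the string is interned only once.
    assert_ne!(x1, x2);
    assert_eq!(ctx.idents.len(), 1);
    if let Node::Identifier(s) = *ctx.node_ref(x1) {
        assert_eq!(ctx.interned(s), "x");
    }
}
```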
Parsing
I use the wonderful lalrpop
parser generator to generate a parser
for the input language. The grammar and actions are given in full below:
// src/parser/grammar.lalrpop
use crate::ast;
use std::str::FromStr;
grammar(ctx: &mut ast::Context);
Integer: i64 = <s:r"[0-9]+"> => i64::from_str(s).unwrap();
Identifier: ast::NodeId = <s:r"[a-zA-Z][a-zA-Z0-9_]*"> => ctx.new_identifier(s);
Sum: ast::NodeId = {
<t:Term> => t,
<l:Sum> "+" <r:Term> => ctx.new_node(ast::Node::Addition(l, r)),
<l:Sum> "-" <r:Term> => ctx.new_node(ast::Node::Subtraction(l, r)),
};
Term: ast::NodeId = {
<i:Item> => i,
<l:Term> "*" <r:Item> => ctx.new_node(ast::Node::Multiplication(l, r)),
<l:Term> "/" <r:Item> => ctx.new_node(ast::Node::Division(l, r)),
<l:Term> ">>" <r:Item> => ctx.new_node(ast::Node::RightShift(l, r)),
<l:Term> "<<" <r:Item> => ctx.new_node(ast::Node::LeftShift(l, r)),
};
Item: ast::NodeId = {
<n:Integer> => ctx.new_node(ast::Node::Const(n)),
"-" <i:Item> => ctx.new_node(ast::Node::Negation(i)),
<i:Identifier> => i,
"(" <s:Start> ")" => s,
};
pub Start: ast::NodeId = {
<s:Sum> => s,
<condition:Sum> "?" <consequent:Sum> ":" <alternative:Sum> =>
ctx.new_node(ast::Node::Conditional(condition, consequent, alternative)),
};
Interpretation
To interpret expressions, we need a lookup
function that maps identifiers to
values, the ast::Context
that owns the AST nodes, and the id of the node we
are evaluating. We match on this node, and handle the following cases:
- If the node represents a constant, we return that node’s associated constant value.
- If the node represents an identifier, we get a reference to its interned identifier string from the context, and then query the lookup function for its value.
- If the node represents an operator, we recursively evaluate its operands and then apply the operator to the operands’ values. For division, we also check for divide-by-zero and return an error.
Here is our initial interpreter function:
// src/eval.rs
pub fn eval<L>(
ctx: &mut ast::Context,
node: ast::NodeId,
lookup: &mut L
) -> Result<i64>
where
L: for<'a> FnMut(&'a str) -> Result<i64>,
{
match *ctx.node_ref(node) {
Node::Const(i) => Ok(i),
Node::Identifier(s) => {
let s = ctx.interned(s);
lookup(s)
}
Node::Addition(lhs, rhs) => {
let lhs = eval(ctx, lhs, lookup)?;
let rhs = eval(ctx, rhs, lookup)?;
Ok(lhs + rhs)
}
Node::Subtraction(lhs, rhs) => {
let lhs = eval(ctx, lhs, lookup)?;
let rhs = eval(ctx, rhs, lookup)?;
Ok(lhs - rhs)
}
Node::Multiplication(lhs, rhs) => {
let lhs = eval(ctx, lhs, lookup)?;
let rhs = eval(ctx, rhs, lookup)?;
Ok(lhs * rhs)
}
Node::Division(lhs, rhs) => {
let lhs = eval(ctx, lhs, lookup)?;
let rhs = eval(ctx, rhs, lookup)?;
if rhs == 0 {
bail!("divide by zero");
}
Ok(lhs / rhs)
}
Node::RightShift(lhs, rhs) => {
let lhs = eval(ctx, lhs, lookup)?;
let rhs = eval(ctx, rhs, lookup)?;
Ok(lhs >> rhs)
}
Node::LeftShift(lhs, rhs) => {
let lhs = eval(ctx, lhs, lookup)?;
let rhs = eval(ctx, rhs, lookup)?;
Ok(lhs << rhs)
}
Node::Negation(n) => {
let n = eval(ctx, n, lookup)?;
Ok(-n)
}
Node::Conditional(condition, consequent, alternative) => {
let condition = eval(ctx, condition, lookup)?;
let consequent = eval(ctx, consequent, lookup)?;
let alternative = eval(ctx, alternative, lookup)?;
Ok(if condition != 0 {
consequent
} else {
alternative
})
}
}
}
From Interpreter to Synthesizer
Our synthesizer will take a specification program and a template program. The template program may contain holes — in our system, these are variables that start with the letter “h”. Our goal is to synthesize constant values for these holes such that the template program implements the specification for all values of the non-hole variables.
Let’s consider the example from the original blog post:
// Specification:
x * 10
// Template:
(x << h1) + (x << h2)
Can we transform multiplication by ten into the sum of two constant left shifts?
If we can find constant values for h1 and h2, then the answer is yes. Our synthesizer should answer that either h1 = 1 and h2 = 3, or that h1 = 3 and h2 = 1.
To implement synthesis, we will walk the AST and generate constraints for the Z3 SMT solver that reflect the program’s semantics. We will do this for both the specification and the template, and then constrain their results to be equal for all values of the non-hole constant variables. Finally, we ask Z3 if it can find a solution to all of the constraints. Any solution that exists provides definitions for the holes.
That is, we are asking the solver to find a solution for

∃ h₁, h₂, …, hₘ: ∀ c₁, c₂, …, cₙ: t = s

where t is the template’s constraints, s is the specification’s constraints, the hᵢ are the holes in the template, and the cⱼ are the constant variables in the template and specification programs. For our running example, this instantiates to ∃ h₁, h₂: ∀ x: (x ≪ h₁) + (x ≪ h₂) = x · 10.
Adrian’s original Python implementation of minisynth takes advantage of Python’s
dynamic nature and the Z3 Python library’s operator overloading to reuse the
interpreter for constraint generation without any changes to the interpreter
function. All you have to do is supply a lookup
function that returns Z3
bitvector variables instead of signed integers. A neat trick!
For our Rust implementation, we want to reuse the interpreter as well, but Rust
is statically typed and the z3
crate for Rust doesn’t implement operator
overloading. So we will factor out an interpret
function from our eval
function that is generic over some abstract interpreter.
An abstract interpreter must have an associated output type. For normal evaluation, this will be an i64, and for constraint generation it will be a Z3 expression. The abstract interpreter must have methods for evaluating each operation of the input language: each method takes operands of the output type, applies its operation to them, and returns a result of the output type. It must also provide a way to translate constants and identifiers into the output type.
// src/abstract_interpret.rs
pub trait AbstractInterpret {
/// The output type of this interpreter.
type Output;
/// Create a constant output value.
fn constant(&mut self, c: i64) -> Self::Output;
/// `lhs + rhs`
fn add(&mut self, lhs: &Self::Output, rhs: &Self::Output) -> Self::Output;
/// `lhs - rhs`
fn sub(&mut self, lhs: &Self::Output, rhs: &Self::Output) -> Self::Output;
/// `lhs * rhs`
fn mul(&mut self, lhs: &Self::Output, rhs: &Self::Output) -> Self::Output;
/// `lhs / rhs`. Fails on divide by zero.
fn div(&mut self, lhs: &Self::Output, rhs: &Self::Output) -> Result<Self::Output>;
/// `lhs >> rhs`
fn shr(&mut self, lhs: &Self::Output, rhs: &Self::Output) -> Self::Output;
/// `lhs << rhs`
fn shl(&mut self, lhs: &Self::Output, rhs: &Self::Output) -> Self::Output;
/// `-e`
fn neg(&mut self, e: &Self::Output) -> Self::Output;
/// Returns `1` if `lhs == rhs`, returns `0` otherwise.
fn eq(&mut self, lhs: &Self::Output, rhs: &Self::Output) -> Self::Output;
/// Returns `1` if `lhs != rhs`, returns `0` otherwise.
fn neq(&mut self, lhs: &Self::Output, rhs: &Self::Output) -> Self::Output;
/// Perform variable lookup for the identifier `var`.
fn lookup(&mut self, var: &str) -> Result<Self::Output>;
}
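To illustrate why this factoring pays off, here is a stripped-down, std-only sketch with just constants and addition: one generic walker drives both an i64 evaluator and, standing in for Z3 constraint generation, a String-building interpreter. The names here are illustrative, not from the actual source.

```rust
// A miniature version of the AbstractInterpret idea: the walker is generic
// over the interpreter, so evaluation and "constraint generation" share code.
trait Interp {
    type Output;
    fn constant(&mut self, c: i64) -> Self::Output;
    fn add(&mut self, lhs: &Self::Output, rhs: &Self::Output) -> Self::Output;
}

enum Expr {
    Const(i64),
    Add(Box<Expr>, Box<Expr>),
}

// One walker, many interpreters.
fn interpret<A: Interp>(interp: &mut A, e: &Expr) -> A::Output {
    match e {
        Expr::Const(c) => interp.constant(*c),
        Expr::Add(l, r) => {
            let l = interpret(interp, l);
            let r = interpret(interp, r);
            interp.add(&l, &r)
        }
    }
}

// Concrete evaluation: Output = i64.
struct Eval;
impl Interp for Eval {
    type Output = i64;
    fn constant(&mut self, c: i64) -> i64 { c }
    fn add(&mut self, l: &i64, r: &i64) -> i64 { l + r }
}

// Stand-in for constraint generation: build a string instead of a Z3 AST.
struct Emit;
impl Interp for Emit {
    type Output = String;
    fn constant(&mut self, c: i64) -> String { c.to_string() }
    fn add(&mut self, l: &String, r: &String) -> String {
        format!("(bvadd {} {})", l, r)
    }
}

fn main() {
    let e = Expr::Add(Box::new(Expr::Const(2)), Box::new(Expr::Const(3)));
    assert_eq!(interpret(&mut Eval, &e), 5);
    assert_eq!(interpret(&mut Emit, &e), "(bvadd 2 3)");
}
```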
Next we make an interpretation function that takes an abstract interpreter and
uses it to interpret an expression of our input language. This looks almost the
same as our original eval function, but there is one tricky bit: encoding the conditional’s semantics into interpreter methods without using Rust’s control flow, which would be invisible to the solver. To do this, we multiply the activated conditional arm by one and the deactivated arm by zero, then sum the two products. An alternative approach would be to add a method for interpreting conditionals directly to the AbstractInterpret trait.
// src/abstract_interpret.rs
pub fn interpret<A>(
interpreter: &mut A,
ctx: &mut ast::Context,
node: ast::NodeId,
) -> Result<A::Output>
where
A: AbstractInterpret,
{
match *ctx.node_ref(node) {
Node::Const(i) => Ok(interpreter.constant(i)),
Node::Identifier(s) => {
let s = ctx.interned(s);
interpreter.lookup(s)
}
Node::Addition(lhs, rhs) => {
let lhs = interpret(interpreter, ctx, lhs)?;
let rhs = interpret(interpreter, ctx, rhs)?;
Ok(interpreter.add(&lhs, &rhs))
}
Node::Subtraction(lhs, rhs) => {
let lhs = interpret(interpreter, ctx, lhs)?;
let rhs = interpret(interpreter, ctx, rhs)?;
Ok(interpreter.sub(&lhs, &rhs))
}
Node::Multiplication(lhs, rhs) => {
let lhs = interpret(interpreter, ctx, lhs)?;
let rhs = interpret(interpreter, ctx, rhs)?;
Ok(interpreter.mul(&lhs, &rhs))
}
Node::Division(lhs, rhs) => {
let lhs = interpret(interpreter, ctx, lhs)?;
let rhs = interpret(interpreter, ctx, rhs)?;
interpreter.div(&lhs, &rhs)
}
Node::RightShift(lhs, rhs) => {
let lhs = interpret(interpreter, ctx, lhs)?;
let rhs = interpret(interpreter, ctx, rhs)?;
Ok(interpreter.shr(&lhs, &rhs))
}
Node::LeftShift(lhs, rhs) => {
let lhs = interpret(interpreter, ctx, lhs)?;
let rhs = interpret(interpreter, ctx, rhs)?;
Ok(interpreter.shl(&lhs, &rhs))
}
Node::Negation(e) => {
let e = interpret(interpreter, ctx, e)?;
Ok(interpreter.neg(&e))
}
Node::Conditional(condition, consequent, alternative) => {
let condition = interpret(interpreter, ctx, condition)?;
let consequent = interpret(interpreter, ctx, consequent)?;
let alternative = interpret(interpreter, ctx, alternative)?;
let zero = interpreter.constant(0);
let neq_zero = interpreter.neq(&condition, &zero);
let eq_zero = interpreter.eq(&condition, &zero);
let consequent = interpreter.mul(&neq_zero, &consequent);
let alternative = interpreter.mul(&eq_zero, &alternative);
Ok(interpreter.add(&consequent, &alternative))
}
}
}
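To double-check that the multiply-and-sum encoding agrees with ordinary control flow, here is a tiny std-only sketch of the same trick over plain i64 values (a sanity check, not part of the synthesizer):

```rust
// Branch-free encoding of `condition ? consequent : alternative`:
// multiply the taken arm by 1 and the untaken arm by 0, then sum.
fn select(condition: i64, consequent: i64, alternative: i64) -> i64 {
    let neq_zero = (condition != 0) as i64; // 1 when the condition holds
    let eq_zero = (condition == 0) as i64;  // 1 when it does not
    neq_zero * consequent + eq_zero * alternative
}

fn main() {
    assert_eq!(select(1, 27, 42), 27);
    assert_eq!(select(0, 27, 42), 42);
    // Any non-zero condition selects the consequent, matching `x ? 27 : 42`.
    assert_eq!(select(-5, 27, 42), 27);
}
```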
We refactor eval
to apply an implementation of AbstractInterpret
that has an
i64
associated output type and directly evaluates expressions:
// src/eval.rs
struct Eval<'a> {
env: &'a HashMap<String, i64>,
}
impl<'a> AbstractInterpret for Eval<'a> {
type Output = i64;
fn constant(&mut self, c: i64) -> i64 { c }
fn lookup(&mut self, var: &str) -> Result<i64> {
self.env
.get(var)
.cloned()
.ok_or_else(|| format_err!("undefined variable: {}", var))
}
fn neg(&mut self, e: &i64) -> i64 { -e }
fn add(&mut self, lhs: &i64, rhs: &i64) -> i64 { lhs + rhs }
fn sub(&mut self, lhs: &i64, rhs: &i64) -> i64 { lhs - rhs }
fn mul(&mut self, lhs: &i64, rhs: &i64) -> i64 { lhs * rhs }
fn shr(&mut self, lhs: &i64, rhs: &i64) -> i64 { lhs >> rhs }
fn shl(&mut self, lhs: &i64, rhs: &i64) -> i64 { lhs << rhs }
fn div(&mut self, lhs: &i64, rhs: &i64) -> Result<i64> {
if *rhs == 0 {
bail!("divide by zero");
}
Ok(lhs / rhs)
}
fn eq(&mut self, lhs: &i64, rhs: &i64) -> i64 {
(lhs == rhs) as i64
}
fn neq(&mut self, lhs: &i64, rhs: &i64) -> i64 {
(lhs != rhs) as i64
}
}
pub fn eval(
ctx: &mut ast::Context,
node: NodeId,
env: &HashMap<String, i64>
) -> Result<i64> {
let eval = &mut Eval { env };
interpret(eval, ctx, node)
}
Finally, we are ready to start implementing synthesis!
First we create an implementation of AbstractInterpret
that builds up Z3
constraints. Its lookup
method keeps track of which variables have been used,
categorizes them by whether they are a hole or an unknown constant, and makes
sure that subsequent lookups of the same identifier return the same Z3
variable. All other methods map straightforwardly onto Z3 method calls.
// src/synthesize.rs
struct Synthesize<'a, 'ctx>
where
'ctx: 'a,
{
ctx: &'ctx z3::Context,
vars: &'a mut HashMap<String, z3::Ast<'ctx>>,
holes: &'a mut HashMap<z3::Ast<'ctx>, String>,
const_vars: &'a mut HashSet<z3::Ast<'ctx>>,
}
impl<'a, 'ctx> AbstractInterpret for Synthesize<'a, 'ctx> {
type Output = z3::Ast<'ctx>;
fn lookup(&mut self, var: &str) -> Result<z3::Ast<'ctx>> {
if !self.vars.contains_key(var) {
let c = self.ctx.fresh_bitvector_const(var, 64);
self.vars.insert(var.to_string(), c.clone());
if var.starts_with("h") {
self.holes.insert(c, var.to_string());
} else {
self.const_vars.insert(c);
}
}
Ok(self.vars[var].clone())
}
fn constant(&mut self, c: i64) -> z3::Ast<'ctx> {
z3::Ast::bitvector_from_i64(self.ctx, c as i64, 64)
}
fn add(&mut self, lhs: &z3::Ast<'ctx>, rhs: &z3::Ast<'ctx>) -> z3::Ast<'ctx> {
lhs.bvadd(rhs)
}
fn sub(&mut self, lhs: &z3::Ast<'ctx>, rhs: &z3::Ast<'ctx>) -> z3::Ast<'ctx> {
lhs.bvsub(rhs)
}
fn mul(&mut self, lhs: &z3::Ast<'ctx>, rhs: &z3::Ast<'ctx>) -> z3::Ast<'ctx> {
lhs.bvmul(rhs)
}
fn div(&mut self, lhs: &z3::Ast<'ctx>, rhs: &z3::Ast<'ctx>) -> Result<z3::Ast<'ctx>> {
Ok(lhs.bvsdiv(rhs))
}
fn shr(&mut self, lhs: &z3::Ast<'ctx>, rhs: &z3::Ast<'ctx>) -> z3::Ast<'ctx> {
lhs.bvlshr(&rhs)
}
fn shl(&mut self, lhs: &z3::Ast<'ctx>, rhs: &z3::Ast<'ctx>) -> z3::Ast<'ctx> {
lhs.bvshl(&rhs)
}
fn neg(&mut self, e: &z3::Ast<'ctx>) -> z3::Ast<'ctx> {
e.bvneg()
}
fn eq(&mut self, lhs: &z3::Ast<'ctx>, rhs: &z3::Ast<'ctx>) -> z3::Ast<'ctx> {
lhs._eq(rhs).ite(&self.constant(1), &self.constant(0))
}
fn neq(&mut self, lhs: &z3::Ast<'ctx>, rhs: &z3::Ast<'ctx>) -> z3::Ast<'ctx> {
lhs._eq(rhs).not().ite(&self.constant(1), &self.constant(0))
}
}
Our synthesis function will take the specification program and the template
program, and then use the Synthesize
abstract interpreter to generate
constraints for each of them.
// src/synthesize.rs
pub fn synthesize<'a>(
z3_ctx: &'a z3::Context,
ast_ctx: &mut ast::Context,
specification: NodeId,
template: NodeId,
) -> Result<HashMap<String, i64>> {
let mut vars = HashMap::new();
let mut holes = HashMap::new();
let mut const_vars = HashSet::new();
let synth = &mut Synthesize {
ctx: z3_ctx,
vars: &mut vars,
holes: &mut holes,
const_vars: &mut const_vars,
};
let specification = interpret(synth, ast_ctx, specification)?;
if !synth.holes.is_empty() {
bail!("the specification cannot have any holes!");
}
let template = interpret(synth, ast_ctx, template)?;
// ...
}
Next, we extract the constant variables and create our goal, which is a for-all constraint. The template must be equal to the specification for all possible values these constants could take.
let const_vars: Vec<_> = const_vars.iter().collect();
let templ_eq_spec = specification._eq(&template);
let goal = if const_vars.is_empty() {
templ_eq_spec
} else {
z3::Ast::forall_const(&const_vars, &templ_eq_spec)
};
Now that we have constructed our goal, we ask Z3 to solve it. If it can find an answer, we extract the value it assigned to each of the holes and return the results as a hash map.
let solver = z3::Solver::new(z3_ctx);
solver.assert(&goal);
if solver.check() {
let model = solver.get_model();
let mut results = HashMap::new();
for (hole, name) in holes {
results.insert(name, model.eval(&hole).unwrap().as_i64().unwrap());
}
Ok(results)
} else {
bail!("no solution")
}
And now we have a synthesizer!
When given
// Specification:
x * 10
// Template:
(x << h1) + (x << h2)
our synthesizer gives the answer
{
"h1": 1,
"h2": 3,
}
And when given
// Specification:
x * 9
// Template:
x << (hb1 ? x : hn1) + (hb2 ? x : hn2)
it gives us the answer
{
"hb1": 0,
"hb2": 1,
"hn1": 3,
"hn2": 0,
}
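Both answers are easy to sanity-check with ordinary Rust arithmetic. Substituting hb1 = 0, hn1 = 3, hb2 = 1 into the second template reduces it to (x << 3) + x:

```rust
fn main() {
    for x in -100i64..=100 {
        // First example: x * 10 as two left shifts, with h1 = 1 and h2 = 3.
        assert_eq!((x << 1) + (x << 3), x * 10);

        // Second example: with hb1 = 0, hn1 = 3, hb2 = 1, the template
        // x << (hb1 ? x : hn1) + (hb2 ? x : hn2) becomes (x << 3) + x.
        assert_eq!((x << 3) + x, x * 9);
    }
}
```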
Conclusion
This was quite fun!
Thanks to Adrian Sampson for writing the original blog post and minisynth Python implementation.
If you would like to learn more, here are a few resources:
- Program Synthesis by Gulwani et al. An overview of the field of program synthesis.
- Synthesizing Constants by John Regehr. A blog post about synthesizing constants in practice, informed by the experience of working on Souper. Lots of links to more papers to read.
- Rosette is a Racket-based language by Emina Torlak that leverages the similarity between interpreters and constraint generation for synthesizers to make developing synthesis and verification tooling easy and fun.