IR Manipulations and Code Generation
Basics
Besides the Symbolics.jl front end, JuliaSimCompiler.jl also supports tracing Julia functions with SymIR directly. For example:
using JuliaSimCompiler
const IR = JuliaSimCompiler
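# Apply `f` to `x` repeatedly, `n` times; `where {F}` makes the method specialize on the type of `f`.
function repeatf(f::F, x, n) where {F}
    y = x
    for _ in 1:n
        y = f(y)
    end
    y
end
# One step of the computation; `foo` applies it `n` times (6 by default).
foo_once(x) = sin(x + 0.1 * x^3)
foo(x, n = 6) = repeatf(foo_once, x, n)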
x = IR.Var(:x)
y = foo(x)
%1 = sin((x + (0.1 * (x ^ 3))))
%2 = sin((%1 + (0.1 * (%1 ^ 3))))
%3 = sin((%2 + (0.1 * (%2 ^ 3))))
%4 = sin((%3 + (0.1 * (%3 ^ 3))))
%5 = sin((%4 + (0.1 * (%4 ^ 3))))
%6 = sin((%5 + (0.1 * (%5 ^ 3))))
In the above example, we construct a SymIR object x with the name :x and trace it through the function foo. %1, %2, etc., are the names of the intermediate variables, i.e., SSA (static single assignment) values. The default printing of SymIR is as concise as possible: it inlines variables, constants, and all SSA values that are used at most once. To see the unabbreviated version, use show. Every SymIR object is simply a Vector of IRElements, which can be accessed by calling bindings:
show(y)
%1 = x
%2 = 0.1
%3 = 3
%4 = (%1 ^ %3)
%5 = (%2 * %4)
%6 = (%1 + %5)
%7 = sin(%6)
%8 = (%7 ^ %3)
%9 = (%2 * %8)
%10 = (%7 + %9)
%11 = sin(%10)
%12 = (%11 ^ %3)
%13 = (%2 * %12)
%14 = (%11 + %13)
%15 = sin(%14)
%16 = (%15 ^ %3)
%17 = (%2 * %16)
%18 = (%15 + %17)
%19 = sin(%18)
%20 = (%19 ^ %3)
%21 = (%2 * %20)
%22 = (%19 + %21)
%23 = sin(%22)
%24 = (%23 ^ %3)
%25 = (%2 * %24)
%26 = (%23 + %25)
%27 = sin(%26)
IR.bindings(y)
27-element Vector{IRElement}:
x
0.1
3
(%1 ^ %3)
(%2 * %4)
(%1 + %5)
sin(%6)
(%7 ^ %3)
(%2 * %8)
(%7 + %9)
⋮
sin(%18)
(%19 ^ %3)
(%2 * %20)
(%19 + %21)
sin(%22)
(%23 ^ %3)
(%2 * %24)
(%23 + %25)
sin(%26)
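To make the idea of a flat list of bindings concrete, here is a minimal sketch of an SSA recorder in plain Julia. It is a hypothetical toy, not JuliaSimCompiler's implementation: Tape, SSA, leaf, emit, variable, and constant are names made up for illustration, and we assume (as the show(y) output suggests) that variables and constants are recorded once and then reused.

```julia
# Hypothetical toy SSA recorder, NOT the JuliaSimCompiler API: every traced
# operation becomes one numbered binding (%1, %2, ...).
struct Tape
    stmts::Vector{Tuple{Symbol,Vector{Any}}}  # (operation, operands) per binding
    leaves::Dict{Any,Int}                      # interned variables and constants
end
Tape() = Tape(Tuple{Symbol,Vector{Any}}[], Dict{Any,Int}())

# A traced value is just a reference to a binding on a tape.
struct SSA
    tape::Tape
    id::Int
end

# Variables and constants are recorded only once and reused afterwards.
function leaf(t::Tape, kind::Symbol, value)
    id = get!(t.leaves, (kind, value)) do
        push!(t.stmts, (kind, Any[value]))
        length(t.stmts)
    end
    return SSA(t, id)
end
variable(t::Tape, name::Symbol) = leaf(t, :var, name)
constant(t::Tape, c::Number) = leaf(t, :const, c)

# Every other operation appends a new binding whose operands are SSA indices.
function emit(t::Tape, op::Symbol, args::SSA...)
    push!(t.stmts, (op, Any[a.id for a in args]))
    return SSA(t, length(t.stmts))
end

# Only the operations used by foo_once are overloaded.
Base.:+(a::SSA, b::SSA) = emit(a.tape, :+, a, b)
Base.:*(c::Real, b::SSA) = emit(b.tape, :*, constant(b.tape, c), b)
Base.:^(a::SSA, n::Integer) = emit(a.tape, :^, a, constant(a.tape, n))
Base.sin(a::SSA) = emit(a.tape, :sin, a)

function repeatf(f::F, x, n) where {F}
    y = x
    for _ in 1:n
        y = f(y)
    end
    y
end
foo_once(x) = sin(x + 0.1 * x^3)

t = Tape()
x = variable(t, :x)
y = repeatf(foo_once, x, 6)

# Print the bindings in the same "%i = ..." style as show(y).
for (i, (op, args)) in enumerate(t.stmts)
    rhs = op in (:var, :const) ? string(args[1]) : string(op, "(", join(("%$a" for a in args), ", "), ")")
    println("%$i = ", rhs)
end
```

Tracing foo_once six times records 27 bindings, the same length as IR.bindings(y). The first few bindings come out in a slightly different order than in show(y) (this toy records the constant 3 before 0.1), because each constant is lifted only when the operation that uses it runs.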
The intermediate variables let a SymIR object grow only linearly with the number of operations, and they also make manipulations and code generation more efficient. If we use Symbolics.jl to trace through foo instead, the printed expression grows exponentially: each iteration uses the previous result twice (as y and inside y^3), so the expression roughly doubles in size with every iteration:
using Symbolics
@variables v
show(foo(v))
sin(sin(sin(sin(sin(sin(v + 0.1(v^3)) + 0.1(sin(v + 0.1(v^3))^3)) + 0.1(sin(sin(v + 0.1(v^3)) + 0.1(sin(v + 0.1(v^3))^3))^3)) + 0.1(sin(sin(sin(v + 0.1(v^3)) + 0.1(sin(v + 0.1(v^3))^3)) + 0.1(sin(sin(v + 0.1(v^3)) + 0.1(sin(v + 0.1(v^3))^3))^3))^3)) + 0.1(sin(sin(sin(sin(v + 0.1(v^3)) + 0.1(sin(v + 0.1(v^3))^3)) + 0.1(sin(sin(v + 0.1(v^3)) + 0.1(sin(v + 0.1(v^3))^3))^3)) + 0.1(sin(sin(sin(v + 0.1(v^3)) + 0.1(sin(v + 0.1(v^3))^3)) + 0.1(sin(sin(v + 0.1(v^3)) + 0.1(sin(v + 0.1(v^3))^3))^3))^3))^3)) + 0.1(sin(sin(sin(sin(sin(v + 0.1(v^3)) + 0.1(sin(v + 0.1(v^3))^3)) + 0.1(sin(sin(v + 0.1(v^3)) + 0.1(sin(v + 0.1(v^3))^3))^3)) + 0.1(sin(sin(sin(v + 0.1(v^3)) + 0.1(sin(v + 0.1(v^3))^3)) + 0.1(sin(sin(v + 0.1(v^3)) + 0.1(sin(v + 0.1(v^3))^3))^3))^3)) + 0.1(sin(sin(sin(sin(v + 0.1(v^3)) + 0.1(sin(v + 0.1(v^3))^3)) + 0.1(sin(sin(v + 0.1(v^3)) + 0.1(sin(v + 0.1(v^3))^3))^3)) + 0.1(sin(sin(sin(v + 0.1(v^3)) + 0.1(sin(v + 0.1(v^3))^3)) + 0.1(sin(sin(v + 0.1(v^3)) + 0.1(sin(v + 0.1(v^3))^3))^3))^3))^3))^3))
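To quantify that growth, the following sketch builds the same computation as a plain Julia Expr tree, where the previous result is written out again at every use, and counts its nodes. foo_once_expr and treesize are hypothetical helpers for illustration; they do not use Symbolics.jl, so the counts describe the shape of the printed tree rather than Symbolics' internal representation.

```julia
# Same helper as repeatf in the documentation.
function repeatf(f::F, x, n) where {F}
    y = x
    for _ in 1:n
        y = f(y)
    end
    y
end

# One step of foo as an expression tree: the argument `e` appears twice.
foo_once_expr(e) = :(sin($e + 0.1 * $e ^ 3))

# Count nodes in a tree of :call expressions: each call counts once
# (its function name is not counted separately), plus all of its arguments.
treesize(e::Expr) = 1 + sum(treesize, e.args[2:end])
treesize(::Any) = 1   # symbols and numeric literals are leaves

for n in 1:6
    println(n, " iterations: ", treesize(repeatf(foo_once_expr, :v, n)), " tree nodes")
end
```

The node count satisfies T(n) = 2T(n-1) + 6 with T(0) = 1, i.e., T(n) = 7 * 2^n - 6 (442 nodes for n = 6), whereas the SSA form printed by show(y) gains only four bindings per iteration.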
Manipulations
JuliaSimCompiler.jl provides a set of basic manipulations for SymIR objects. We can:
- substitute numbers or expressions for variables (substitute),
- compute the derivative or jacobian of an expression, and
- optimize an expression (optimize).
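To replace the variable x with a number or an expression, use substitute with a vector of variable => value pairs. Substituting a number evaluates the expression, while substituting y for x composes foo with itself, doubling the number of iterations from 6 to 12: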
IR.substitute(y, [x => 1])
0.6635777923146026
IR.substitute(y, [x => y])
%1 = sin((x + (0.1 * (x ^ 3))))
%2 = sin((%1 + (0.1 * (%1 ^ 3))))
%3 = sin((%2 + (0.1 * (%2 ^ 3))))
%4 = sin((%3 + (0.1 * (%3 ^ 3))))
%5 = sin((%4 + (0.1 * (%4 ^ 3))))
%6 = sin((%5 + (0.1 * (%5 ^ 3))))
%7 = sin((%6 + (0.1 * (%6 ^ 3))))
%8 = sin((%7 + (0.1 * (%7 ^ 3))))
%9 = sin((%8 + (0.1 * (%8 ^ 3))))
%10 = sin((%9 + (0.1 * (%9 ^ 3))))
%11 = sin((%10 + (0.1 * (%10 ^ 3))))
%12 = sin((%11 + (0.1 * (%11 ^ 3))))
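Because the bindings are stored in order, substituting numbers for every variable amounts to evaluating each binding once, from first to last, so the cost is linear in the size of the IR. The following sketch assumes a hypothetical encoding of the bindings (a Symbol for a variable, a Number for a constant, and a tuple (f, i, j, ...) for applying f to earlier bindings); foo_bindings and substitute_numbers are illustrative names, not part of JuliaSimCompiler.

```julia
# Hypothetical encoding of foo's bindings, in the same order as show(y):
# %1 = x, %2 = 0.1, %3 = 3, then four bindings (^, *, +, sin) per iteration.
function foo_bindings(n)
    bs = Any[:x, 0.1, 3]
    prev = 1                              # index of the current result
    for _ in 1:n
        k = length(bs)
        push!(bs, (^, prev, 3), (*, 2, k + 1), (+, prev, k + 2), (sin, k + 3))
        prev = k + 4
    end
    return bs
end

# Substitute numbers for variables by evaluating every binding exactly once.
function substitute_numbers(bindings::Vector{Any}, subs::Dict{Symbol,<:Number})
    vals = Vector{Any}(undef, length(bindings))
    for (i, b) in enumerate(bindings)
        vals[i] = if b isa Symbol
            subs[b]                       # variable: look up its value
        elseif b isa Number
            b                             # constant
        else
            f, args = b[1], b[2:end]      # (function, operand indices...)
            f((vals[j] for j in args)...)
        end
    end
    return vals[end]                      # the last binding is the result
end

bs = foo_bindings(6)
substitute_numbers(bs, Dict(:x => 1))
```

With six iterations this encoding has 27 bindings, like IR.bindings(y), and evaluating it at x = 1 gives approximately 0.66358, in line with IR.substitute(y, [x => 1]).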
To compute the derivative of an expression, use derivative:
dy = IR.derivative(y, x)
%1 = (x + (0.1 * (x ^ 3)))
%2 = sin(%1)
%3 = (%2 + (0.1 * (%2 ^ 3)))
%4 = sin(%3)
%5 = (%4 + (0.1 * (%4 ^ 3)))
%6 = sin(%5)
%7 = (%6 + (0.1 * (%6 ^ 3)))
%8 = sin(%7)
%9 = (%8 + (0.1 * (%8 ^ 3)))
%10 = sin(%9)
%11 = (%10 + (0.1 * (%10 ^ 3)))
%12 = (cos(%1) * ((0 + (0.1 * ((3 * (x ^ 2)) * 1))) + 1))
%13 = (cos(%3) * ((0 + (0.1 * ((3 * (%2 ^ 2)) * %12))) + %12))
%14 = (cos(%5) * ((0 + (0.1 * ((3 * (%4 ^ 2)) * %13))) + %13))
%15 = (cos(%7) * ((0 + (0.1 * ((3 * (%6 ^ 2)) * %14))) + %14))
%16 = (cos(%9) * ((0 + (0.1 * ((3 * (%8 ^ 2)) * %15))) + %15))
%17 = (cos(%11) * ((0 + (0.1 * ((3 * (%10 ^ 2)) * %16))) + %16))
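In the output above, bindings %1 to %11 recompute the primal values and %12 to %17 apply the chain rule one step at a time, each reusing the previous derivative, which is the pattern of forward-mode differentiation. The following sketch evaluates that pattern numerically on a hypothetical encoding of the bindings (a Symbol for the variable, a Number for a constant, a tuple (f, i, j...) for an operation); foo_bindings and derivative_at are illustrative names, and only the operations used by foo are supported.

```julia
# Hypothetical encoding of foo's bindings (same order as show(y)).
function foo_bindings(n)
    bs = Any[:x, 0.1, 3]
    prev = 1
    for _ in 1:n
        k = length(bs)
        push!(bs, (^, prev, 3), (*, 2, k + 1), (+, prev, k + 2), (sin, k + 3))
        prev = k + 4
    end
    return bs
end

# Forward mode: carry a value and a derivative d/d(var) for every binding.
function derivative_at(bindings::Vector{Any}, var::Symbol, at::Real)
    val = zeros(length(bindings))          # primal values
    der = zeros(length(bindings))          # derivatives with respect to `var`
    for (i, b) in enumerate(bindings)
        if b isa Symbol                    # assumes `var` is the only variable
            b == var || error("unexpected variable $b")
            val[i], der[i] = at, 1.0
        elseif b isa Number                # a constant has derivative 0
            val[i], der[i] = b, 0.0
        elseif b[1] === sin
            j = b[2]
            val[i], der[i] = sin(val[j]), cos(val[j]) * der[j]
        elseif b[1] === (+)
            j, k = b[2], b[3]
            val[i], der[i] = val[j] + val[k], der[j] + der[k]
        elseif b[1] === (*)
            j, k = b[2], b[3]
            val[i], der[i] = val[j] * val[k], der[j] * val[k] + val[j] * der[k]
        elseif b[1] === (^)                # assumes a constant exponent
            j, k = b[2], b[3]
            p = val[k]
            val[i], der[i] = val[j]^p, p * val[j]^(p - 1) * der[j]
        else
            error("unsupported operation $(b[1])")
        end
    end
    return der[end]
end

derivative_at(foo_bindings(6), :x, 1.0)
```

For n = 1 this reproduces cos(1.1) * 1.3, the derivative of sin(x + 0.1x^3) at x = 1 worked out by hand.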
Note that dy is a SymIR object as well, so we can compute higher-order derivatives simply by chaining derivative. To compute the 4th derivative, we can do:
d4y′ = repeatf(y -> IR.derivative(y, x), y, 4)
nothing
The optimize function performs peephole optimization, copy propagation, and global value numbering:
d4y = IR.optimize(d4y′)
length(IR.bindings(d4y′)), length(IR.bindings(d4y))
(4776, 488)
These passes shrink the intermediate representation (IR) from 4776 to 488 bindings, roughly a tenfold reduction.
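The rewrite rules these passes use are not spelled out here. As a rough, hypothetical illustration of the three ideas on a simple list of bindings (a Symbol for a variable, a Number for a constant, a tuple (f, i, j...) for an operation), the sketch simplifies x + 0 and x * 1 (peephole), forwards uses of a removed binding to its replacement through a renaming table (copy propagation), and reuses an identical earlier binding instead of recording it twice (global value numbering via a hash table). optimize_bindings is an illustrative name, not JuliaSimCompiler's optimize, and it does not remove unused bindings.

```julia
# Hypothetical optimizer for a list of bindings; returns (new bindings, result index).
function optimize_bindings(bindings::Vector{Any})
    out = Any[]                    # optimized bindings
    seen = Dict{Any,Int}()         # value numbering: binding -> index in `out`
    rename = Int[]                 # old index -> new index (copy propagation)
    is_value(i, c) = out[i] isa Number && out[i] == c
    for b in bindings
        if b isa Tuple
            f = b[1]
            args = [rename[j] for j in b[2:end]]   # operands after renaming
            # Peephole rules: x + 0 -> x, x * 1 -> x, 0 + x -> x, 1 * x -> x.
            if f === (+) && is_value(args[2], 0) || f === (*) && is_value(args[2], 1)
                push!(rename, args[1]); continue
            elseif f === (+) && is_value(args[1], 0) || f === (*) && is_value(args[1], 1)
                push!(rename, args[2]); continue
            end
            key = (f, args...)
        else
            key = b                # variables and constants
        end
        # Global value numbering: reuse an identical earlier binding if there is one.
        idx = get!(seen, key) do
            push!(out, key)
            length(out)
        end
        push!(rename, idx)
    end
    return out, rename[end]
end

bs = Any[:x, 0, 1, 2, 3,
         (^, 1, 4),      # %6  = x^2
         (*, 5, 6),      # %7  = 3 * x^2
         (*, 7, 3),      # %8  = %7 * 1        (peephole: becomes %7)
         (+, 2, 8),      # %9  = 0 + %8        (peephole: becomes %7)
         (^, 1, 4),      # %10 = x^2 again     (value numbering: reuses %6)
         (*, 5, 10),     # %11 = 3 * x^2 again (value numbering: reuses %7)
         (+, 9, 11)]     # %12 = %9 + %11
out, res = optimize_bindings(bs)
```

The 12 input bindings shrink to 8, and the result becomes (+, 7, 7), i.e., 3x^2 + 3x^2 with the shared subexpression computed once.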
Code Generation
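JuliaSimCompiler.jl also provides code generation utilities. We can generate a Julia expression with the toexpr function, passing the IR together with the variable that becomes the argument of the generated function, and eval it to obtain a callable function: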
d4yfun = eval(IR.toexpr(d4y, x))
d4yfun(1)
-2.370014342893697
We can compare the above result against ForwardDiff.jl:
using ForwardDiff
D = f -> (x -> ForwardDiff.derivative(f, x))  # D maps a function to its derivative function
repeatf(D, foo, 4)(1)
-2.370014342893697
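Conceptually, turning a list of bindings into Julia code needs little more than one assignment per binding. The sketch below does this for a simple, hypothetical encoding of the bindings (a Symbol for a variable, a Number for a constant, and a tuple (f, i, j...) for applying f to earlier bindings). foo_bindings and bindings_to_expr are illustrative names, not JuliaSimCompiler's toexpr: binding %i becomes a local variable v<i>, and the returned Expr is an anonymous function of the chosen argument.

```julia
# Hypothetical encoding of foo's bindings (same order as show(y)).
function foo_bindings(n)
    bs = Any[:x, 0.1, 3]
    prev = 1
    for _ in 1:n
        k = length(bs)
        push!(bs, (^, prev, 3), (*, 2, k + 1), (+, prev, k + 2), (sin, k + 3))
        prev = k + 4
    end
    return bs
end

# Emit `(arg,) -> begin v1 = ...; v2 = ...; ...; v_last end`.
function bindings_to_expr(bindings::Vector{Any}, arg::Symbol)
    name(i) = Symbol(:v, i)                  # binding %i becomes local v<i>
    body = Expr(:block)
    for (i, b) in enumerate(bindings)
        rhs = b isa Tuple ? Expr(:call, b[1], (name(j) for j in b[2:end])...) : b
        push!(body.args, :($(name(i)) = $rhs))
    end
    push!(body.args, name(length(bindings)))  # return the last binding
    return :(($arg,) -> $body)
end

ex = bindings_to_expr(foo_bindings(6), :x)
foo_fun = eval(ex)
foo_fun(1)
```

The generated code is straight-line: because reuse is already explicit in the bindings, each value is computed once and no extra common-subexpression handling is needed at this stage.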
We can use set_array to express filling an array with in-place operations. We will demonstrate this capability by computing the Jacobian of a function:
x, y, z = IR.Var.((:x, :y, :z))
goo((x, y, z)) = [x^2 + y - z
sin(y * z) / x
hypot(x, y)]
ir = IR.set_array(goo([x, y, z]))
jac = IR.jacobian(ir, [x, y, z])
%1 = sin((y * z))
%2 = x
%3 = x
%4 = hypot(%3, y)
%5 = ((%3 * 1) / ifelse((%4 == 0), 1, %4))
%6 = setindex!(JuliaSimCompiler.Argument(1), (0 + (0 + ((2 * (x ^ 1)) * 1))), CartesianIndex(1, 1))
%7 = setindex!(JuliaSimCompiler.Argument(1), ((-(((%1 / %2) / %2)) * 1) + 0), CartesianIndex(2, 1))
%8 = setindex!(JuliaSimCompiler.Argument(1), %5, CartesianIndex(3, 1))
%9 = (y * z)
%10 = y
%11 = hypot(x, %10)
%12 = ((%10 * 1) / ifelse((%11 == 0), 1, %11))
%13 = setindex!(JuliaSimCompiler.Argument(1), (0 + 1), CartesianIndex(1, 2))
%14 = setindex!(JuliaSimCompiler.Argument(1), (0 + ((1 / x) * (cos(%9) * ((1 * z) + 0)))), CartesianIndex(2, 2))
%15 = setindex!(JuliaSimCompiler.Argument(1), %12, CartesianIndex(3, 2))
%16 = (y * z)
%17 = setindex!(JuliaSimCompiler.Argument(1), -(1), CartesianIndex(1, 3))
%18 = setindex!(JuliaSimCompiler.Argument(1), (0 + ((1 / x) * (cos(%16) * (0 + (y * 1))))), CartesianIndex(2, 3))
%19 = setindex!(JuliaSimCompiler.Argument(1), 0, CartesianIndex(3, 3))
%20 = (0 + ((1 / x) * (cos(%16) * (0 + (y * 1)))))
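In the IR above, every Jacobian entry is written with setindex! into JuliaSimCompiler.Argument(1), which appears to be the first argument of the generated function (the out array passed to jac_fun! further down). A hypothetical, expression-level sketch of the same idea: given one Julia expression per output element, inplace_expr (an illustrative name, not a JuliaSimCompiler function) builds an anonymous function that unpacks its input vector, writes each result into out[i], and returns out.

```julia
# Build `(out, u) -> begin (x, y, z) = u; out[1] = e1; ...; out end`.
function inplace_expr(exprs::Vector, vars::Vector{Symbol})
    body = Expr(:block, :(($(vars...),) = u))   # unpack the input vector
    for (i, e) in enumerate(exprs)
        push!(body.args, :(out[$i] = $e))        # one in-place write per output
    end
    push!(body.args, :out)                       # return the filled array
    return :((out, u) -> $body)
end

goo_exprs = [:(x^2 + y - z), :(sin(y * z) / x), :(hypot(x, y))]
goo! = eval(inplace_expr(goo_exprs, [:x, :y, :z]))
out = zeros(3)
goo!(out, [1, 2, 3])
```

Writing into a caller-provided array avoids allocating a new output on every call, which is the usual motivation for in-place functions in Julia.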
To generate the Julia expression, we can use toexpr again:
jac_fun! = eval(IR.toexpr(jac, [x, y, z], check = false))
out = zeros(3, 3)
jac_fun!(out, [1, 2, 3])
3×3 Matrix{Float64}:
2.0 1.0 -1.0
0.279415 2.88051 1.92034
0.447214 0.894427 0.0
The check = false keyword argument disables the check on the axes of the arguments in the generated function; see the documentation of toexpr for more details. We can compare the above result against ForwardDiff.jl:
ForwardDiff.jacobian(goo, [1, 2, 3])
3×3 Matrix{Float64}:
2.0 1.0 -1.0
0.279415 2.88051 1.92034
0.447214 0.894427 0.0
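The Jacobian IR above fills the output column by column (CartesianIndex(1, 1), (2, 1), (3, 1), then column 2, and so on). A numeric analogue of that column-wise, in-place fill is a forward-mode Jacobian with dual numbers that seeds one input at a time; ForwardDiff.jl is also based on dual numbers. This sketch is only an illustration, not how JuliaSimCompiler computes jacobian: the minimal Dual type and jacobian! are hypothetical and implement just the operations goo needs.

```julia
# Minimal dual number: a value and its derivative along one seeded direction.
struct Dual
    val::Float64
    der::Float64
end
Base.:+(a::Dual, b::Dual) = Dual(a.val + b.val, a.der + b.der)
Base.:-(a::Dual, b::Dual) = Dual(a.val - b.val, a.der - b.der)
Base.:*(a::Dual, b::Dual) = Dual(a.val * b.val, a.der * b.val + a.val * b.der)
Base.:/(a::Dual, b::Dual) = Dual(a.val / b.val, (a.der * b.val - a.val * b.der) / b.val^2)
Base.:^(a::Dual, n::Integer) = Dual(a.val^n, n * a.val^(n - 1) * a.der)
Base.sin(a::Dual) = Dual(sin(a.val), cos(a.val) * a.der)
function Base.hypot(a::Dual, b::Dual)
    r = hypot(a.val, b.val)
    # Avoid 0/0 at the origin, similar in spirit to the ifelse guard in the IR.
    Dual(r, (a.val * a.der + b.val * b.der) / (r == 0 ? 1.0 : r))
end

goo((x, y, z)) = [x^2 + y - z, sin(y * z) / x, hypot(x, y)]

# Fill `out` in place, one column per input, by seeding one derivative at a time.
function jacobian!(out::AbstractMatrix, f, u::AbstractVector)
    for j in eachindex(u)
        col = f([Dual(u[k], k == j ? 1.0 : 0.0) for k in eachindex(u)])
        for i in eachindex(col)
            out[i, j] = col[i].der
        end
    end
    return out
end

J = jacobian!(zeros(3, 3), goo, [1, 2, 3])
```

The result matches the 3×3 matrices printed above; for example, J[2, 1] = -sin(6) ≈ 0.279415.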