Speeding Up Your Python Code With Numba
Python is slow partly because it is not compiled: it is an interpreted language. One solution is to compile it. Unfortunately, it's difficult to compile everything you can do in Python, but there are ways to compile functions that use a subset of Python. One tool that does this is Numba.
Here we will see how to use Numba to compile an important subset of functions, which lets our code run significantly faster. Include Numba with `import numba`, or, since the `@njit` decorator is all we need here, with `from numba import njit`.
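The key trick for using Numba is to put the `@njit` decorator on top of each function you want it to compile, e.g.
```python
from numba import njit

@njit
def myPreviouslySlowFunction(beta, Nx, Ny):
    # here I am doing stuff
    myAnswer = 0.0  # placeholder for the real calculation
    return myAnswer
```
This function will now be compiled "just in time" (that's the "jit") the first time it is called. If you call it many times, this will significantly speed up your simulation. The first call is slower because it includes the compilation, but later calls with the same argument types reuse the compiled machine code. The speed-up is especially large if your function has `for` loops in it.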
Three useful hints:
- Numba can't compile everything. Arbitrary Python objects (such as instances of your own classes) and some other Python features aren't supported when compiling with `@njit`. If it complains that it can't compile a function, try to (1) replace the things it is unhappy about and (2) factor the slow parts out into smaller functions that don't use the disallowed features.
- A function compiled with `@njit` can only call other compiled functions (plus the Python built-ins and NumPy functions that Numba supports). So if you're compiling a function, you'll also need to compile the functions it calls. I generally start by getting the innermost functions to compile and then work outwards from there.
- The trick with optimization is to first get something working and then speed up the parts that are slow. Work one step at a time and don't worry about making everything fast: most of your code usually isn't the bottleneck.
Numba for the Ising Model
I ended up needing to do the following to get numba to significantly speed up my Ising model. Currently your code probably looks something like this:
```python
def deltaE(spins, flipX, flipY):
    # do stuff
    return change_in_energy

def Energy(spins):
    # do stuff
    return myEnergy

for sweep in range(0, 10000):
    for step in range(0, N):
        pass  # flip spins
    # measure
```
Here is what I did. I got Numba to compile my `deltaE` and `Energy` functions. I then pulled that `for` loop over steps out into a `Sweep` function and had Numba compile that too. Later in the assignment I also compiled the `coarse_grain` function, which returns the coarse-grained lattice. So I had:
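```python
from numba import njit

@njit
def deltaE(spins, flipX, flipY):
    # do stuff
    return change_in_energy

@njit
def Energy(spins):
    # do stuff
    return myEnergy

@njit
def Sweep(spins):
    # Numba freezes global variables at compile time, so get N from the
    # array itself: one sweep = one attempted flip per spin.
    N = spins.size
    for step in range(0, N):
        pass  # flip spins

for sweep in range(0, 10000):
    Sweep(spins)
    # measure
```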
This was enough for me to do approximately 10,000 sweeps on an \(81 \times 81\) lattice in about ten seconds on Google Colab.