I want to express my sincere gratitude for your guidance, support, and mentorship throughout the Google Summer of Code program. Your expertise and patience have been invaluable in helping me navigate challenges and grow as a developer. I am truly grateful for this incredible learning experience and for the opportunity to work with you. Thank you for your time, insights, and dedication to the project!
Modern hypotheses in neuroscience suggest that information about the world, body movements, or higher-level cognition is encoded by large groups of neurons. So far, neural activity has mostly been examined in single neurons and in pairs of neurons, and it remains largely unexplored how, and which, information is encoded by larger groups. Higher-order interactions (HOIs), i.e. non-pairwise interactions between neurons, are studied to understand how information is encoded by multiplets of neurons, that is, groups of 3, 4, ..., N neurons. A recent advance in information theory introduced a measure called the O-information (short for "information about organisational structure") to quantify and characterize the information carried by multiplets. However, computing HOIs is demanding because the number of possible multiplets grows combinatorially with the number of neurons. For example, with only 10 neurons there are 120 triplets, 210 quadruplets and 252 quintuplets to compute, and 968 multiplets of size 3 to 10 in total.
The goal of my project was to implement and optimize the computation of HOIs using JAX, a recent library for high-performance numerical computing. Here are some key advantages of JAX over NumPy:
GPU/TPU Acceleration: JAX runs on accelerators such as GPUs and TPUs, so the same code can execute on these devices with few or no modifications. NumPy, by contrast, runs only on the CPU.
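Function JIT Compilation: JAX uses Just-In-Time (JIT) compilation through the XLA compiler to speed up computations. A function wrapped with `jax.jit` is traced and compiled the first time it is called with inputs of a given shape and dtype; later calls with matching inputs reuse the cached compiled version, and XLA can fuse many small operations into fewer, faster kernels. This makes repeated calculations, such as evaluating the same measure on thousands of multiplets, much faster.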
Parallelism and Concurrency: JAX offers primitives such as `jax.pmap` and the `jax.sharding` API for distributing computations across multiple devices, for example several GPUs or TPU cores, which improves efficiency on multi-device systems.
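Interoperability with NumPy: JAX provides `jax.numpy`, which mirrors most of the NumPy API, and JAX functions accept NumPy arrays as input. We can therefore often port NumPy code by replacing `np` with `jnp`, leveraging JAX's advantages without a complete rewrite. The main caveat is that JAX arrays are immutable, so in-place updates such as `x[i] = v` must be written as `x = x.at[i].set(v)`.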
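Transformations and Compilers: JAX provides composable function transformations: `jit` compiles a function, `vmap` automatically vectorizes it over a batch dimension, and `grad` computes its gradient. Because they compose, they can be stacked, for example `jit(vmap(f))` to compile a vectorized version of `f`.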
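Returning to the multiplet counts in the introduction: because O-information does not depend on the order of the neurons, a multiplet is an unordered combination, so the number of k-neuron multiplets among N neurons is the binomial coefficient C(N, k) rather than the number of ordered arrangements. The short plain-Python sketch below checks this; `n_neurons` is an illustrative value.

```python
from itertools import combinations
from math import comb

n_neurons = 10  # illustrative population size

# Enumerate unordered multiplets explicitly and check against C(N, k).
counts = {}
for k in (3, 4, 5):
    multiplets = list(combinations(range(n_neurons), k))
    assert len(multiplets) == comb(n_neurons, k)
    counts[k] = len(multiplets)

# All multiplets of size 3..N: every subset except those of size 0, 1 and 2.
total = sum(comb(n_neurons, k) for k in range(3, n_neurons + 1))

print(counts)  # {3: 120, 4: 210, 5: 252}
print(total)   # 968
```

With `n_neurons = 100` the same formula gives 161,700 triplets and 3,921,225 quadruplets, which is why the per-multiplet computation has to be fast.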
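The `jit` and `vmap` transformations can be combined to evaluate O-information on every multiplet of a given size in a single compiled call. The sketch below assumes a Gaussian estimator of entropy (entropies computed from covariance log-determinants) and synthetic random data; the function names `gaussian_entropy` and `oinfo_gaussian` are hypothetical, and this estimator is not necessarily the one used in the project. It uses the standard definition Ω(X) = (n − 2) H(X) + Σ_j [H(X_j) − H(X_{−j})], where positive values indicate redundancy and negative values synergy.

```python
import itertools

import jax
import jax.numpy as jnp
import numpy as np


def gaussian_entropy(cov):
    """Differential entropy (in nats) of a Gaussian with covariance `cov`."""
    k = cov.shape[0]
    _, logdet = jnp.linalg.slogdet(cov)
    return 0.5 * (k * np.log(2 * np.pi * np.e) + logdet)


def oinfo_gaussian(x):
    """O-information of the columns of x, shape (n_samples, n_vars), Gaussian assumption."""
    n = x.shape[1]  # static at trace time, so the loop below is unrolled
    cov = jnp.cov(x, rowvar=False)
    omega = (n - 2) * gaussian_entropy(cov)
    for j in range(n):
        rest = np.delete(np.arange(n), j)  # indices of all variables except j
        omega += gaussian_entropy(cov[j:j + 1, j:j + 1])
        omega -= gaussian_entropy(cov[rest][:, rest])
    return omega


# vmap maps oinfo_gaussian over the leading (multiplet) axis;
# jit compiles the batched function once per input shape.
batched_oinfo = jax.jit(jax.vmap(oinfo_gaussian))

# Synthetic data (assumption): 500 samples of 10 independent "neurons".
rng = np.random.default_rng(0)
data = rng.standard_normal((500, 10)).astype(np.float32)

# All 120 triplets, gathered into one array of shape (120, 500, 3).
triplets = np.array(list(itertools.combinations(range(10), 3)))
stacked = jnp.asarray(data[:, triplets].transpose(1, 0, 2))

oinfo = batched_oinfo(stacked)
print(oinfo.shape)  # (120,)
```

Because `vmap` needs every element of the batch to have the same shape, one batched call is made per multiplet size k, and `jit` compiles a new version for each k. The Python loop inside `oinfo_gaussian` is unrolled during tracing, which is acceptable for small multiplet sizes.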