gopjrt
gopjrt leverages OpenXLA to compile, optimize and accelerate numeric computations (with large data) from Go, using the various backends supported by OpenXLA: CPU, GPUs (NVidia, Intel*, Apple Metal*) and TPU*. It can be used to power machine learning frameworks (e.g. GoMLX), image processing, scientific computation, game AIs, etc.
NEW: Experimental and somewhat limited Apple/Metal support.
And because Jax, TensorFlow and (optionally) PyTorch run on XLA, it is possible to run Jax functions in Go with gopjrt -- and probably TensorFlow and PyTorch functions as well. See example 2 below.
(*) Not tested yet. Please let me know if it works for you, or if you can lend access to this hardware (a virtual machine I can use for a while) -- I would love to verify and make sure it works there.
gopjrt aims to be minimalist and robust: it provides well-maintained, extensible Go wrappers for the OpenXLA PJRT and OpenXLA XlaBuilder libraries. It is not very ergonomic (error handling everywhere), and the expectation is that others will create a friendlier API on top of gopjrt -- the same way Jax is a friendlier API on top of XLA/PJRT. One such friendlier API is GoMLX, a Go machine learning framework, but gopjrt may be used standalone, for lower-level access to XLA and other accelerator use cases -- like running Jax functions in Go.
It provides two independent packages (often used together, but not necessarily):
pjrt: this package loads PJRT plugins -- implementations of PJRT for specific hardware (CPU, GPUs, TPUs, etc.) in the form of a dynamically linked library -- and provides an API to compile and execute "programs". "Programs" for PJRT are specified as serialized StableHLO proto-buffers (HloModuleProto, more specifically). This is an intermediate representation (IR), not usually written directly by humans, that can be output by, for instance, a Jax/PyTorch/TensorFlow program, or by using the xlabuilder package described below.
It includes the following main concepts:
- Client: the first thing created after loading a plugin. Generally one creates a single Client per plugin; it's not very clear to me why one would create more than one Client.
- LoadedExecutable: created when one calls Client.Compile on an HLO program. It is the compiled/optimized/accelerated code, ready to run.
- Buffer: represents a buffer with the input/output data for the computations in the accelerators. There are methods to transfer it to/from host memory. Buffers are the inputs and outputs of LoadedExecutable.Execute.
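A minimal sketch of how these concepts fit together, using the error-to-panic helpers from github.com/janpfeifer/must as in the examples below (hloBlob is assumed to hold a serialized HLO program, obtained for instance as in example 2; the plugin name "cpu" assumes the CPU plugin is installed):
plugin := must.M1(pjrt.GetPlugin("cpu"))                  // Load the PJRT plugin.
client := must.M1(plugin.NewClient(nil))                  // Create the Client.
exec := must.M1(client.Compile().WithHLO(hloBlob).Done()) // Compile to a LoadedExecutable.
input := must.M1(pjrt.ScalarToBuffer(client, float32(3))) // Transfer the input to a device Buffer.
outputs := must.M1(exec.Execute(input).Done())            // Execute; it returns the output Buffers.
fmt.Println(must.M1(pjrt.BufferToScalar[float32](outputs[0]))) // Transfer the result back to host.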
While it uses cgo to dynamically load the plugin and call its C API, the pjrt package doesn't require anything other than the plugin to be installed.
The project release includes 2 plugins: one for CPU compiled from the XLA source code, and one for GPUs provided in the Jax distributed binaries -- both for the linux/x86-64 architecture (help with Macs wanted!). But there are instructions to build your own CPU plugin (e.g., for a different architecture) or GPU plugin (XLA seems to have code to support ROCm, but I'm not sure of its status). And it should work with binary plugins provided by others -- see the plugin references in the PJRT blog post.
xlabuilder: this package provides a Go API for building accelerated computations using the XLA operations. The output of building a computation with xlabuilder is a StableHLO(-ish) program that can be used directly with PJRT (and the pjrt package above). Again, it aims to be minimalist, robust and well maintained, albeit not necessarily very ergonomic.
Main concepts:
- XlaBuilder: the builder object, used to keep track of the operations being added.
- XlaComputation: created with XlaBuilder.Build(...), it represents the finished program, ready to be used by PJRT (or saved to disk). It is also used to represent sub-routines/functions -- see XlaBuilder.CreateSubBuilder and the Call method, sketched after this list.
- Literal: represents constants in the program. It has some similarities with a pjrt.Buffer, but a Literal is only used during the creation of the program. It is usually better to avoid large constants in a program, and rather feed them as pjrt.Buffer inputs to the program during its execution.
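For instance, a sub-computation could be defined and invoked with Call roughly as in the sketch below -- the exact signatures of CreateSubBuilder and Call should be double-checked against the package documentation:
// Build a sub-computation plusOne(x) = x+1 in a sub-builder.
builder := xlabuilder.New("main")
subBuilder := builder.CreateSubBuilder("plus_one")
x := must.M1(xlabuilder.Parameter(subBuilder, "x", 0, xlabuilder.MakeShape(dtypes.F32)))
one := must.M1(xlabuilder.ScalarOne(subBuilder, dtypes.F32))
plusOneComp := must.M1(subBuilder.Build(must.M1(xlabuilder.Add(x, one))))
// Call the sub-computation from the main computation: f(y) = plusOne(y).
y := must.M1(xlabuilder.Parameter(builder, "y", 0, xlabuilder.MakeShape(dtypes.F32)))
comp := must.M1(builder.Build(must.M1(xlabuilder.Call(builder, plusOneComp, y))))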
See examples below.
The xlabuilder package includes a separate C project that generates a libgomlx_xlabuilder.so dynamic library (~13Mb for linux/x86-64) and associated *.h files, which need to be installed. A tar.gz is included in the release for the linux/x86-64 architecture (help with Macs wanted!). But one can also build it from scratch for different platforms -- it uses Bazel, due to its dependencies on OpenXLA/XLA.
Notice that there are alternatives to using XlaBuilder:
- JAX/TensorFlow can output the HLO of JIT-compiled functions, which can be fed directly to PJRT (see example 2).
- Use GoMLX.
- One can use XlaBuilder during development and then save the output (see XlaComputation.SerializedHLO); during production, only the pjrt package is needed to execute it -- see the sketch after this list.
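A sketch of that development/production split, assuming SerializedHLO returns the serialized HloModuleProto as bytes (check the package documentation for the exact return type):
// During development: build the computation (root is the Op `output`) and save it.
comp := must.M1(builder.Build(output))
must.M(os.WriteFile("hlo.pb", comp.SerializedHLO(), 0644)) // Assumes []byte is returned.
// During production: load the blob and execute it using only the pjrt package,
// with a client created as in the examples below.
hloBlob := must.M1(os.ReadFile("hlo.pb"))
exec := must.M1(client.Compile().WithHLO(hloBlob).Done())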
- This is a trivial example. XLA/PJRT really shines when doing large number-crunching tasks.
- The package github.com/janpfeifer/must simply converts errors to panics.
builder := xlabuilder.New("x*x+1")
x := must.M1(xlabuilder.Parameter(builder, "x", 0, xlabuilder.MakeShape(dtypes.F32))) // Scalar float32.
fX := must.M1(xlabuilder.Mul(x, x))
one := must.M1(xlabuilder.ScalarOne(builder, dtypes.Float32))
fX = must.M1(xlabuilder.Add(fX, one))
// Get computation created.
comp := must.M1(builder.Build(fX))
//fmt.Printf("HloModule proto:\n%s\n\n", comp.TextHLO())
// Load the PJRT plugin and create a client.
plugin := must.M1(pjrt.GetPlugin(*flagPluginName))
fmt.Printf("Loaded %s\n", plugin)
client := must.M1(plugin.NewClient(nil))
// Compile program.
loadedExec := must.M1(client.Compile().WithComputation(comp).Done())
fmt.Printf("Compiled program: name=%s, #outputs=%d\n", loadedExec.Name, loadedExec.NumOutputs)
// Test values:
inputs := []float32{0.1, 1, 3, 4, 5}
fmt.Printf("f(x) = x^2 + 1:\n")
for _, input := range inputs {
inputBuffer := must.M1(pjrt.ScalarToBuffer(client, input))
outputBuffers := must.M1(loadedExec.Execute(inputBuffer).Done())
output := must.M1(pjrt.BufferToScalar[float32](outputBuffers[0]))
fmt.Printf("\tf(x=%g) = %g\n", input, output)
}
// Destroy the client and leave.
must.M(client.Destroy())
First, we create the HLO program in Jax/Python (see the Jax documentation). You can do this with Google's Colab, without having to install anything:
import jax

def f(x):
    return x*x + 1

# Trace f to an XLA (HLO) computation.
comp = jax.xla_computation(f)(3.)
print(comp.as_hlo_text())

# Serialize the HloModuleProto to a file.
hlo_proto = comp.as_hlo_module()
with open('hlo.pb', 'wb') as file:
    file.write(hlo_proto.as_serialized_hlo_module_proto())
Then download the hlo.pb file and do:
- (The package github.com/janpfeifer/must simply converts errors to panics.)
hloBlob := must.M1(os.ReadFile("hlo.pb"))
// Load the PJRT plugin and create a client.
plugin := must.M1(pjrt.GetPlugin(*flagPluginName))
fmt.Printf("Loaded %s\n", plugin)
client := must.M1(plugin.NewClient(nil))
loadedExec := must.M1(client.Compile().WithHLO(hloBlob).Done())
// Test values:
inputs := []float32{0.1, 1, 3, 4, 5}
fmt.Printf("f(x) = x^2 + 1:\n")
for _, input := range inputs {
inputBuffer := must.M1(pjrt.ScalarToBuffer(client, input))
outputBuffers := must.M1(loadedExec.Execute(inputBuffer).Done())
output := must.M1(pjrt.BufferToScalar[float32](outputBuffers[0]))
fmt.Printf("\tf(x=%g) = %g\n", input, output)
}
Example 3: Mandelbrot Set Notebook
The notebook includes both the "regular" Go implementation and the corresponding implementation using XlaBuilder and execution with PJRT, for comparison, with some benchmarks.
gopjrt requires a C library and a PJRT plugin module to be installed.
For Linux or Windows+WSL, run the following script (see source) to install under /usr/local/{lib,include}:
curl -sSf https://raw.githubusercontent.com/gomlx/gopjrt/main/cmd/install_linux_amd64.sh | bash
For Linux (or Windows+WSL) + CUDA (NVidia GPU) support, additionally run (see source):
curl -sSf https://raw.githubusercontent.com/gomlx/gopjrt/main/cmd/install_cuda.sh | bash
For Darwin/arm64 (M1, M2) GPU support, run the following script (see source) to install under /usr/local/{lib,include}:
- VERY EXPERIMENTAL: only a subset of the operations and types is supported (float64 doesn't work). See https://developer.apple.com/metal/jax/. And the CPU version of XLA is not working either. It is more of a gopjrt developer version.
curl -sSf https://raw.githubusercontent.com/gomlx/gopjrt/main/cmd/install_darwin_arm64.sh | bash
TODO(Darwin): Create a Homebrew version.
That's it. The next sections explain in more detail, for those interested in special cases.
The install scripts cmd/install_linux_amd64.sh, cmd/install_cuda.sh and cmd/install_darwin_arm64.sh can be controlled to install into any arbitrary directory (by setting GOPJRT_INSTALL_DIR) and to not use sudo (by setting GOPJRT_NOSUDO). You may need to fiddle with LD_LIBRARY_PATH if the installation directory is not standard, and with PJRT_PLUGIN_LIBRARY_PATH to tell gopjrt where to find the plugins.
There are two parts that need installing: (1) the XLA Builder library (a C++ wrapper); (2) the PJRT plugins for the accelerator devices you want to support.
The releases come with prebuilt:
- XLA Builder library for linux/amd64 (or Windows WSL), darwin/arm64 and darwin/amd64. Both MacOS/Darwin releases are EXPERIMENTAL and have somewhat limited functionality (on the PJRT side), see https://developer.apple.com/metal/jax/.
- PJRT plugin for CPU, only for linux/amd64.
The installation scripts download the Linux/CUDA PJRT, or the Darwin/arm64 and Darwin/amd64 PJRTs, from the corresponding Jax pip package.
If you have any questions, or want a custom installation of the XLA Builder library, check and modify cmd/install_linux_amd64.sh, cmd/install_cuda.sh or cmd/install_darwin_arm64.sh (VERY EXPERIMENTAL, GPU ONLY); they are self-explanatory.
The recommended location for plugins is /usr/local/lib/gomlx/pjrt, and that's where the installation scripts cmd/install_linux_amd64.sh, cmd/install_cuda.sh and cmd/install_darwin_arm64.sh install them. But gopjrt will automatically search for PJRT plugins in all standard library locations (configured in /etc/ld.so.conf on Linux). Alternatively, one can set the directory(ies) to search for plugins with the environment variable PJRT_PLUGIN_LIBRARY_PATH.
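For instance, a program could point gopjrt to a custom plugin directory before loading any plugin -- a sketch; the directory path is hypothetical, and exporting the variable in the shell before launching the program works just as well:
// Point gopjrt to a non-standard plugin directory (hypothetical path).
must.M(os.Setenv("PJRT_PLUGIN_LIBRARY_PATH", "/opt/pjrt_plugins"))
plugin := must.M1(pjrt.GetPlugin("cpu"))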
See docs/devel.md for hints on how to compile a plugin from the OpenXLA/XLA sources.
Also, see this blog post with links and references to the Intel and Apple hardware plugins.
- When is feature X from PJRT or XlaBuilder going to be supported?
Indeed, gopjrt doesn't wrap everything -- although it does cover the most common operations. The simple ops and structs are auto-generated, but many require hand-writing. If it is useful to your project, please create an issue; I'm happy to add it. I focused on the needs of GoMLX, but the idea is that gopjrt can serve other purposes, and I'm happy to support them.
- Why not split it into smaller packages?
Because of golang/go#13467: C APIs cannot be exported across packages, even within the same repo. Even a function as simple as func Add(a, b C.int) C.int in one package cannot be called from another. So we need to wrap everything, and, more than that, we cannot create separate sub-packages to handle separate concerns. This is also the reason the library chelper.go is copied in both the pjrt and xlabuilder packages.
- Why does PJRT spit out so much logging? Can we disable it?
This is a great question... imagine if every library we used decided it also wanted to clutter our stderr? I have an open question in Abseil about it. It may be an issue with Abseil Logging, which also has this other issue of not allowing two different linked programs/libraries to call its initialization (see Issue #1656). A hacky workaround is to duplicate fd 2, assign the duplicate to Go's os.Stderr, and then close fd 2, so the PJRT plugins won't have anywhere to log. This hack is encoded in the function pjrt.SuppressAbseilLoggingHack(): just call it before calling pjrt.GetPlugin. But it may have unintended consequences, if some other library depends on fd 2 to work, or if a real exceptional situation needs to be reported and is not.
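Usage is a single call before the plugin is loaded, for example:
// Redirect os.Stderr to a duplicate of fd 2 and close fd 2 itself,
// so the PJRT plugin's Abseil logging has nowhere to go.
pjrt.SuppressAbseilLoggingHack()
plugin := must.M1(pjrt.GetPlugin("cpu"))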
- Google Drive Directory with Design Docs: Some links are outdated or redirected, but very valuable information.
- How to use the PJRT C API? #openxla/xla/issues/7038: discussion of folks trying to use PJRT in their projects. Some examples leveraging some of the XLA C++ library.
- How to use PJRT C API v.2 #openxla/xla/issues/7038.
- PJRT C API README.md: a collection of links to other documents.
- Public Design Document.
- Gemini helped quite a bit parsing/understanding things -- despite the hallucinations -- other AIs may help as well.
This project utilizes the following components from the OpenXLA project:
- A (slightly modified) copy of OpenXLA's pjrt_c_api.h file.
- OpenXLA PJRT CPU Plugin: this plugin enables execution of XLA computations on the CPU.
- OpenXLA PJRT CUDA Plugin: this plugin enables execution of XLA computations on NVIDIA GPUs.
We gratefully acknowledge the OpenXLA team for their valuable work in developing and maintaining these plugins.
gopjrt is licensed under the Apache 2.0 license.
The OpenXLA project, including the pjrt_c_api.h file and the CPU and CUDA plugins, is licensed under the Apache 2.0 license.
The CUDA plugin also utilizes the NVIDIA CUDA Toolkit, which is subject to NVIDIA's licensing terms and must be installed by the user.
For more information about OpenXLA, please visit their website at openxla.org, or the GitHub page at github.com/openxla/xla.