gomlx/gopjrt

gopjrt (Installing)


gopjrt leverages OpenXLA to compile, optimize and accelerate numeric computations (with large data) from Go, using the various backends supported by OpenXLA: CPU, GPUs (NVIDIA, Intel*, Apple Metal*) and TPU*. It can be used to power Machine Learning frameworks (e.g. GoMLX), image processing, scientific computation, game AIs, etc.

NEW: Experimental and somewhat limited Apple/Metal support.

And because Jax, TensorFlow and optionally PyTorch run on XLA, it is possible to run Jax functions in Go with gopjrt, and probably TensorFlow and PyTorch functions as well. See Example 2 below.

(*) Not tested yet. Please let me know if it works for you, or if you can lend access to this hardware (e.g. a virtual machine) for a while: I would love to verify that it works there.

gopjrt aims to be minimalist and robust: it provides well-maintained, extensible Go wrappers for the OpenXLA PJRT and OpenXLA XlaBuilder libraries.

It is not very ergonomic (error handling everywhere), and the expectation is that others will create friendlier APIs on top of gopjrt -- the same way Jax is a friendlier API on top of XLA/PJRT.

One such friendlier API is GoMLX, a Go machine learning framework, but gopjrt may also be used standalone, for lower-level access to XLA and other accelerator use cases -- like running Jax functions in Go.

It provides 2 independent packages (often used together, but not necessarily):

github.com/gomlx/gopjrt/pjrt

This package loads PJRT plugins -- implementations of PJRT for specific hardware (CPU, GPUs, TPUs, etc.) in the form of a dynamically linked library -- and provides an API to compile and execute "programs".

"Programs" for PJRT are specified as "StableHLO serialized proto-buffers" (more specifically, HloModuleProto). This is an intermediate representation (IR), not usually written directly by humans, that can be output by, for instance, a Jax/PyTorch/TensorFlow program, or by the xlabuilder package described below.

It includes the following main concepts:

  • Client: the first thing created after loading a plugin. It seems one can create a singleton Client per plugin; it's not very clear to me why one would create more than one Client.
  • LoadedExecutable: created by calling Client.Compile on an HLO program. It's the compiled/optimized/accelerated code, ready to run.
  • Buffer: holds the input/output data for the computations on the accelerators, with methods to transfer it to/from host memory. Buffers are the inputs and outputs of LoadedExecutable.Execute.

While it uses CGO to dynamically load the plugin and call its C API, pjrt doesn't require anything other than the plugin to be installed.

The project release includes 2 plugins, both for the linux/x86-64 architecture (help with Mac wanted!): one for CPU, compiled from the XLA source code, and one for GPUs, taken from the Jax distributed binaries. There are also instructions to build your own CPU plugin (e.g. for a different architecture) or GPU plugin (XLA seems to have code to support ROCm, but I'm not sure of its status). And it should work with binary plugins provided by others -- see the plugin references in the PJRT blog post.

github.com/gomlx/gopjrt/xlabuilder

This provides a Go API for building accelerated computations using the XLA operations. The output of building a computation with xlabuilder is a StableHLO(-ish) program that can be used directly with PJRT (and the pjrt package above).

Again, it aims to be minimalist, robust and well maintained, albeit not necessarily ergonomic.

Main concepts:

  • XlaBuilder: builder object, used to keep track of the operations being added.
  • XlaComputation: created with XlaBuilder.Build(...); it represents the finished program, ready to be used by PJRT (or saved to disk). It is also used to represent sub-routines/functions -- see XlaBuilder.CreateSubBuilder and the Call method.
  • Literal: represents constants in the program. It has some similarities with a pjrt.Buffer, but a Literal is only used during the creation of the program. It is usually better to avoid large constants in a program and instead feed them as pjrt.Buffer inputs when the program is executed.

See examples below.
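As a rough illustration of why large constants are better fed as pjrt.Buffer inputs: a Literal's values are stored inside the program itself, so its raw payload is the element size times the number of elements. The sketch below uses a hypothetical helper, denseBytes, and ignores any serialization overhead:

```go
package main

import "fmt"

// denseBytes is a hypothetical helper returning the raw payload size, in bytes, of a
// dense array with elements of elemSize bytes and the given dimensions. A scalar
// (no dimensions) has a single element. Serialization overhead is ignored.
func denseBytes(elemSize int, dims ...int) int {
	size := elemSize
	for _, d := range dims {
		size *= d
	}
	return size
}

func main() {
	// A [1024, 1024] float32 matrix (4 bytes per element) baked in as a Literal.
	fmt.Printf("%d bytes\n", denseBytes(4, 1024, 1024)) // prints "4194304 bytes"
}
```

So a single [1024, 1024] float32 constant adds about 4 MiB to the program, while the same data passed as an input Buffer adds nothing to it and can change between executions.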

The xlabuilder package includes a separate C project that generates a libgomlx_xlabuilder.so dynamic library (~13MB for linux/x86-64) and associated *.h files, which need to be installed. A tar.gz is included in the release for the linux/x86-64 architecture (help for Macs wanted!). One can also build it from scratch for different platforms -- it uses Bazel due to its dependency on OpenXLA/XLA.

Notice that there are alternatives to using XlaBuilder:

  • JAX/TensorFlow can output the HLO of JIT-compiled functions, which can be fed directly to PJRT (see Example 2).
  • Use GoMLX.
  • One can use XlaBuilder during development, save the output (see XlaComputation.SerializedHLO), and then in production use only the pjrt package to execute it.

Examples

Example 1: Create a function with XlaBuilder and execute it with PJRT:

  • This is a trivial example. XLA/PJRT really shines when doing large number-crunching tasks.
  • The package github.com/janpfeifer/must simply converts errors to panics.
  package main

  import (
    "flag"
    "fmt"

    "github.com/gomlx/gopjrt/dtypes"
    "github.com/gomlx/gopjrt/pjrt"
    "github.com/gomlx/gopjrt/xlabuilder"
    "github.com/janpfeifer/must"
  )

  // PJRT plugin to use, e.g. "cpu" or "cuda" (flag name and default chosen for this example).
  var flagPluginName = flag.String("plugin", "cpu", "PJRT plugin name or full path")

  func main() {
    flag.Parse()

    // Build f(x) = x*x + 1.
    builder := xlabuilder.New("x*x+1")
    x := must.M1(xlabuilder.Parameter(builder, "x", 0, xlabuilder.MakeShape(dtypes.Float32))) // Scalar float32.
    fX := must.M1(xlabuilder.Mul(x, x))
    one := must.M1(xlabuilder.ScalarOne(builder, dtypes.Float32))
    fX = must.M1(xlabuilder.Add(fX, one))

    // Get computation created.
    comp := must.M1(builder.Build(fX))
    //fmt.Printf("HloModule proto:\n%s\n\n", comp.TextHLO())

    // Load the PJRT plugin and create a client.
    plugin := must.M1(pjrt.GetPlugin(*flagPluginName))
    fmt.Printf("Loaded %s\n", plugin)
    client := must.M1(plugin.NewClient(nil))

    // Compile program.
    loadedExec := must.M1(client.Compile().WithComputation(comp).Done())
    fmt.Printf("Compiled program: name=%s, #outputs=%d\n", loadedExec.Name, loadedExec.NumOutputs)

    // Test values:
    inputs := []float32{0.1, 1, 3, 4, 5}
    fmt.Printf("f(x) = x^2 + 1:\n")
    for _, input := range inputs {
      inputBuffer := must.M1(pjrt.ScalarToBuffer(client, input))
      outputBuffers := must.M1(loadedExec.Execute(inputBuffer).Done())
      output := must.M1(pjrt.BufferToScalar[float32](outputBuffers[0]))
      fmt.Printf("\tf(x=%g) = %g\n", input, output)
    }

    // Destroy the client and leave (Destroy only returns an error, hence must.M).
    must.M(client.Destroy())
  }
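To check the values printed by the program above, a plain Go version of the same function is handy. This is only a reference sketch (the name f is arbitrary); since floating-point operations may be fused or reordered by the compilers involved, the last bits of non-integer results can differ slightly from the PJRT output:

```go
package main

import "fmt"

// f is a plain Go reference for the XlaBuilder program: f(x) = x*x + 1, in float32.
func f(x float32) float32 {
	return x*x + 1
}

func main() {
	// Same test values as the PJRT example.
	for _, x := range []float32{0.1, 1, 3, 4, 5} {
		fmt.Printf("\tf(x=%g) = %g\n", x, f(x))
	}
}
```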

Example 2: Execute a Jax function in Go with pjrt:

First, we create the HLO program in Jax/Python (see the Jax documentation).

(You can do this with Google's Colab without having to install anything.)

import os
import jax

def f(x): 
  return x*x+1

comp = jax.xla_computation(f)(3.)
print(comp.as_hlo_text())
hlo_proto = comp.as_hlo_module()

with open('hlo.pb', 'wb') as file:
  file.write(hlo_proto.as_serialized_hlo_module_proto())

Then download the hlo.pb file and do:

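  package main

  import (
    "flag"
    "fmt"
    "os"

    "github.com/gomlx/gopjrt/pjrt"
    "github.com/janpfeifer/must"
  )

  var flagPluginName = flag.String("plugin", "cpu", "PJRT plugin name or full path") // Flag name/default chosen for this example.

  func main() {
    flag.Parse()
    hloBlob := must.M1(os.ReadFile("hlo.pb"))

    // Load the PJRT plugin, create a client and compile the HLO program.
    plugin := must.M1(pjrt.GetPlugin(*flagPluginName))
    fmt.Printf("Loaded %s\n", plugin)
    client := must.M1(plugin.NewClient(nil))
    loadedExec := must.M1(client.Compile().WithHLO(hloBlob).Done())

    // Test values:
    inputs := []float32{0.1, 1, 3, 4, 5}
    fmt.Printf("f(x) = x^2 + 1:\n")
    for _, input := range inputs {
      inputBuffer := must.M1(pjrt.ScalarToBuffer(client, input))
      outputBuffers := must.M1(loadedExec.Execute(inputBuffer).Done())
      output := must.M1(pjrt.BufferToScalar[float32](outputBuffers[0]))
      fmt.Printf("\tf(x=%g) = %g\n", input, output)
    }
    must.M(client.Destroy())
  }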

The notebook includes both the "regular" Go implementation and the corresponding implementation using XlaBuilder and execution with PJRT for comparison, with some benchmarks.

Installing

TL;DR

gopjrt requires a C library and a plugin module to be installed.

For Linux or Windows+WSL, run the following script (see source) to install under /usr/local/{lib,include}:

curl -sSf https://raw.githubusercontent.com/gomlx/gopjrt/main/cmd/install_linux_amd64.sh | bash

For Linux (or Windows+WSL) with CUDA (NVIDIA GPU) support, also run (see source):

curl -sSf https://raw.githubusercontent.com/gomlx/gopjrt/main/cmd/install_cuda.sh | bash

For Darwin/arm64 (M1, M2) GPU support, run the following script (see source) to install under /usr/local/{lib,include}:

  • VERY EXPERIMENTAL: only a subset of the operations and types is supported (float64 doesn't work), see https://developer.apple.com/metal/jax/. The CPU version of XLA is not working either. It is more of a gopjrt developer version.
curl -sSf https://raw.githubusercontent.com/gomlx/gopjrt/main/cmd/install_darwin_arm64.sh | bash

TODO(Darwin): Create a Homebrew version.

That's it. The next sections explain things in more detail, for those interested in special cases.

More details

The install scripts cmd/install_linux_amd64.sh, cmd/install_cuda.sh and cmd/install_darwin_arm64.sh can be configured to install into an arbitrary directory (by setting GOPJRT_INSTALL_DIR) and not to use sudo (by setting GOPJRT_NOSUDO). If the installation directory is not standard, you may need to adjust LD_LIBRARY_PATH, and set PJRT_PLUGIN_LIBRARY_PATH to tell gopjrt where to find the plugins.

There are two parts that need installing: (1) the XLA Builder library (a C++ wrapper); (2) PJRT plugins for the accelerator devices you want to support.

The releases come with prebuilt:

  1. XLA Builder library for linux/amd64 (or Windows WSL), darwin/arm64 and darwin/amd64. Both macOS/Darwin releases are EXPERIMENTAL and have somewhat limited functionality (on the PJRT side), see https://developer.apple.com/metal/jax/.
  2. PJRT plugin for CPU, for linux/amd64 only.

The installation scripts download the Linux/CUDA PJRT plugin, or the Darwin/arm64 and Darwin/amd64 PJRT plugins, from the corresponding Jax pip packages.

Installing XLA Builder

If you have any questions, or want a custom installation of the XLA Builder library, check and modify cmd/install_linux_amd64.sh, cmd/install_cuda.sh or cmd/install_darwin_arm64.sh (VERY EXPERIMENTAL, GPU ONLY): they are self-explanatory.

Installing PJRT plugins

The recommended location for plugins is /usr/local/lib/gomlx/pjrt, and that's where the installation scripts cmd/install_linux_amd64.sh, cmd/install_cuda.sh and cmd/install_darwin_arm64.sh install them.

But gopjrt will automatically search for PJRT plugins in all standard library locations (configured in /etc/ld.so.conf on Linux). Alternatively, one can set the directory(ies) to search for plugins by setting the environment variable PJRT_PLUGIN_LIBRARY_PATH.

Plugins for other devices or platforms

See docs/devel.md for hints on how to compile a plugin from the OpenXLA/XLA sources.

Also, see this blog post with links and references to the Intel and Apple hardware plugins.

FAQ

  • When is feature X from PJRT or XlaBuilder going to be supported? Yes, gopjrt doesn't wrap everything -- although it does cover the most common operations. The simple ops and structs are auto-generated, but many require hand-writing. If it is useful to your project, please create an issue; I'm happy to add it. I focused on the needs of GoMLX, but the idea is that it can serve other purposes, and I'm happy to support them.
  • Why not split into smaller packages? Because of golang/go#13467: C APIs cannot be exported across packages, even within the same repo. Even a function as simple as func Add(a, b C.int) C.int in one package cannot be called from another. So we need to wrap everything, and, more than that, one cannot create separate sub-packages to handle separate concerns. This is also the reason the library chelper.go is copied in both the pjrt and xlabuilder packages.
  • Why does PJRT spit out so much logging? Can we disable it? This is a great question ... imagine if every library we use decided it also wanted to clutter our stderr! I have an open question in Abseil about it. It may be some issue with Abseil Logging, which also has this other issue of not allowing two different linked programs/libraries to call its initialization (see Issue #1656). A hacky workaround is duplicating fd 2 and assigning it to Go's os.Stderr, and then closing fd 2, so PJRT plugins have nowhere to log. This hack is encoded in the function pjrt.SuppressAbseilLoggingHack(): just call it before calling pjrt.GetPlugin. But it may have unintended consequences, if some other library depends on fd 2 to work, or if a real exceptional situation needs to be reported and is not.

Links to documentation

Acknowledgements

This project utilizes the following components from the OpenXLA project:

  • This project includes a (slightly modified) copy of OpenXLA's pjrt_c_api.h file.

  • OpenXLA PJRT CPU Plugin: This plugin enables execution of XLA computations on the CPU.

  • OpenXLA PJRT CUDA Plugin: This plugin enables execution of XLA computations on NVIDIA GPUs.

  • We gratefully acknowledge the OpenXLA team for their valuable work in developing and maintaining these plugins.

Licensing:

gopjrt is licensed under the Apache 2.0 license.

The OpenXLA project, including the pjrt_c_api.h file and the CPU and CUDA plugins, is licensed under the Apache 2.0 license.

The CUDA plugin also utilizes the NVIDIA CUDA Toolkit, which is subject to NVIDIA's licensing terms and must be installed by the user.

For more information about OpenXLA, please visit their website at openxla.org, or the GitHub page at github.com/openxla/xla.