Jax

JAX is an open-source library developed by Google Research for high-performance numerical computing and machine learning, with first-class support for automatic differentiation.

Select Jax on AIchor by setting spec.operator: jax in your manifest.yaml. AIchor schedules one pod per replica, designates the first replica as the JAX distributed coordinator, and injects the environment variables needed for jax.distributed.initialize(...) to bring the workers together into a single distributed run.

How to use

A minimal manifest selecting Jax:

kind: "AIchorManifest"
apiVersion: "0.2.3"

spec:
  operator: "jax"
  image: "image"
  command: "python3 -u main.py --operator=jax --sleep=300 --tb-write=True"

  types:
    worker:
      count: 2 # how many replicas of the worker pod will be deployed
      resources:
        cpus: 1
        ramRatio: 2
        accelerators: # optional
          gpu:
            count: 1
            type: "gpu"
            product: "NVIDIA-A100-SXM4-80GB"

For more complete examples (TPU, multi-GPU, …) see Manifest Examples and the full schema in the Manifest Reference.

Injected environment variables

AIchor injects the following environment variables into every worker pod. Under the hood, AIchor relies on JobSets to manage JAX distributed runs, and these variables provide each pod with the coordination details it needs. The example values below assume the two-worker manifest above, where the coordinator is the first pod:

| Variable | Description | Example |
| --- | --- | --- |
| JAXOPERATOR_COORDINATOR_ADDRESS | Address of the coordinator (the first pod), including the host and the port it listens on. This is the value to pass to jax.distributed.initialize. | worker-0-0.experiment-6f4a:1234 |
| JAXOPERATOR_COORDINATOR_HOST | The coordinator pod's hostname, without a port. Read from the jobset.sigs.k8s.io/coordinator pod label that JobSet stamps on every pod. JobSet designates which pod is the coordinator; this variable only exposes that pod's address. | worker-0-0.experiment-6f4a |
| JAXOPERATOR_NUM_PROCESSES | Total number of Jax pods running in parallel (equals the worker count). | 2 |
| JAXOPERATOR_PROCESS_ID | Rank of the current pod among all pods in the job, from 0 to JAXOPERATOR_NUM_PROCESSES - 1. The coordinator is always 0; every other pod gets a distinct non-zero rank. | 0 on the coordinator, 1 on the second worker |
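
Because the coordinator always has rank 0, the injected variables are enough to decide which pod should do one-off work such as writing TensorBoard logs. The following is a minimal sketch, not part of AIchor itself; the helper name pod_role is hypothetical, and it assumes only the variable names listed above.

```python
import os


def pod_role(env=os.environ):
    """Return (rank, world_size, is_coordinator) from the injected variables."""
    rank = int(env["JAXOPERATOR_PROCESS_ID"])
    world_size = int(env["JAXOPERATOR_NUM_PROCESSES"])
    # The rank is documented as ranging from 0 to JAXOPERATOR_NUM_PROCESSES - 1.
    if not 0 <= rank < world_size:
        raise ValueError(f"rank {rank} is outside [0, {world_size})")
    # The coordinator is always rank 0.
    return rank, world_size, rank == 0
```

A script could call pod_role() once at startup and, for example, only write logs when the third element is True.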
Note: JAXOPERATOR_COORDINATOR_ADDRESS is set to the literal value $(JAXOPERATOR_COORDINATOR_HOST):1234, which a shell expands to the coordinator host plus port 1234 when the command starts. If the pod image cannot evaluate that expression, build the address yourself: append a port to JAXOPERATOR_COORDINATOR_HOST (for example, :1234) and pass the result to jax.distributed.initialize as the coordinator address.

Demo project

The AIchor team maintains a demo project that can be cloned and used as a starting point for Jax experiments on AIchor:

Further reading