
Google TPUs on AIchor

Some AI engineers train their models on TPUs rather than GPUs or CPUs, either because of the framework they use or because of the type of model being trained.
In some cases, users also have partnership or privileged access on GCP to more cost-effective TPU nodes.

For these cases, AIchor admins can import or create a GKE engine on GCP with TPUs.

Prerequisites

To use TPU devices on AIchor, the engine must be created or imported in a GCP project that has sufficient TPU quotas, and in a region or zone where the requested TPU version is offered and covered by that quota.

⚠️ The quotas must be requested from the "Compute Engine API" service and not from the "Cloud TPU" service.

For example, a correct quota request for TPU on AIchor would be:

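| Field                      | Value                                  |
|----------------------------|----------------------------------------|
| Service                    | Compute Engine API                     |
| Dimensions (e.g. location) | region:europe-west4                    |
| Name                       | Preemptible TPU v5 Lite Podslice chips |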

Learn more about TPU quotas for GKE in the Google Cloud documentation.
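To size a quota request, one approach is to multiply the chips in one slice by the number of slices the manifest asks for. The sketch below assumes the quota is counted in chips, as its name suggests, and that experiments running at the same time add up against the same quota. `chips_in_topology` and `chips_quota_needed` are hypothetical helpers, not part of AIchor or GCP tooling; they only parse topology strings like the ones used on this page.

```python
from math import prod


def chips_in_topology(topology: str) -> int:
    """Chips in one TPU slice, e.g. "4x4" -> 16 (also works for "2x2x2")."""
    return prod(int(dim) for dim in topology.lower().split("x"))


def chips_quota_needed(topology: str, slice_count: int = 1) -> int:
    """Chips of quota used by `slice_count` slices of the same topology."""
    return chips_in_topology(topology) * slice_count


# Topologies used by the manifests on this page.
print(chips_quota_needed("2x2"))     # 4 chips
print(chips_quota_needed("4x4"))     # 16 chips
print(chips_quota_needed("4x4", 2))  # 32 chips: count: 2 requests two slices
print(chips_quota_needed("2x4"))     # 8 chips
```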

Dependencies and libraries

To use the TPU devices, the container image must include the machine learning library together with libtpu, the TPU runtime library.

For example, with JAX:

pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html

Or with PyTorch:

pip install torch~=2.5.0 "torch_xla[tpu]~=2.5.0" -f https://storage.googleapis.com/libtpu-releases/index.html
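Once the image is built, a small check run inside the container can confirm which of these libraries is installed and whether JAX actually picked the TPU backend. This is a sketch: `tpu_libraries_report` is a hypothetical helper name. On a TPU node with `jax[tpu]` installed the expected JAX backend is `"tpu"`; elsewhere JAX falls back to another backend such as `"cpu"`.

```python
import importlib.util


def tpu_libraries_report() -> dict:
    """Reports which TPU-capable frameworks are importable in this image."""
    report = {
        # find_spec only locates the package; it does not import it.
        "jax": importlib.util.find_spec("jax") is not None,
        "torch_xla": importlib.util.find_spec("torch_xla") is not None,
        "jax_backend": None,
    }
    if report["jax"]:
        import jax  # imported lazily so the check also works without JAX

        # "tpu" when libtpu is installed and TPU chips are visible, else e.g. "cpu".
        report["jax_backend"] = jax.default_backend()
    return report


print(tpu_libraries_report())
```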

Start a single-host TPU experiment

A single-host TPU experiment uses a TPU slice made of a single VM. Each TPU version has its own topology designs; the Google Cloud TPU documentation for each version shows whether a given topology is single-host or multi-host.

A single-host TPU experiment can be started with any AIchor operator, as in this manifest excerpt:

...
spec:
  # example with jax. Could be any other AIchor operator
  operator: "jax"
  ...
  types:
    worker:
      # increasing the number of workers will request more single host TPU slices.
      # As multi-slice is not yet supported, this provides multiple independent single host TPU slices.
      count: 1
      resources:
        cpus: 90
        ramRatio: 2
        accelerators:
          tpu:
            type: "tpu-v5-lite-podslice"
            topology: "2x2"
            tpuChipsCount: 4

In the example above, a 2x2 TPU slice on TPU v5e contains 2x2 = 4 chips; with tpuChipsCount: 4 chips per VM, it requests 4/4 = 1 VM.
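The same division tells you whether a topology is single-host or multi-host, and therefore which operator can run it. A minimal sketch, assuming tpuChipsCount is the number of chips per VM as in the manifests on this page; `vms_per_slice` and `operator_for` are hypothetical helpers. They only do the arithmetic and do not check which topology and machine combinations Google actually offers for a given TPU version.

```python
from math import prod


def vms_per_slice(topology: str, chips_per_vm: int) -> int:
    """Hosts (VMs) in one TPU slice: chips in the topology / tpuChipsCount."""
    chips = prod(int(dim) for dim in topology.lower().split("x"))
    if chips % chips_per_vm != 0:
        raise ValueError(f"{topology} does not split into VMs of {chips_per_vm} chips")
    return chips // chips_per_vm


def operator_for(topology: str, chips_per_vm: int) -> str:
    """Single-host slices run with any AIchor operator; multi-host ones need kuberay."""
    return "any" if vms_per_slice(topology, chips_per_vm) == 1 else "kuberay"


print(vms_per_slice("2x2", 4), operator_for("2x2", 4))  # 1 any
print(vms_per_slice("4x4", 4), operator_for("4x4", 4))  # 4 kuberay
print(vms_per_slice("2x4", 4), operator_for("2x4", 4))  # 2 kuberay
```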

Start a multi-host TPU experiment with kuberay

Larger TPU slices span multiple hosts (VMs). Multi-host TPU slices require the kuberay operator; the other operators do not support them.

...
spec:
  operator: "kuberay"
  ...
  types:
    Head:
      resources:
        cpus: 4
        ramRatio: 2

    Workers:
      - name: "tpu-pool"
        # When requesting a TPU slice, the number of VMs is implied by the given topology.
        # For example, this manifest requests one 4x4 TPU slice with 4 chips per VM: (4x4)/4 = 4 VMs.
        # If count is 2, it requests two 4x4 TPU slices.
        # Multi-slice is not yet supported.
        count: 1
        resources:
          cpus: 90
          ramRatio: 2
          accelerators:
            tpu:
              type: "tpu-v5-lite-podslice"
              topology: "4x4"
              tpuChipsCount: 4

Example code using this 4x4 v5e TPU slice:

import ray
import jax
import os


# Each task asks for 4 TPU chips, i.e. one whole v5e VM of the slice.
@ray.remote(resources={"TPU": 4})
def tpu_cores(index: int):
    # Global count: chips visible across all hosts of the slice.
    device_count = jax.device_count()
    return f"Index [{index}]: Global TPU Count={device_count}"


if __name__ == "__main__":
    ray.init(address=os.environ.get("RAY_ADDRESS", "auto"), log_to_driver=True)

    num_workers = 4  # number of worker pods (VMs) in a 4x4 v5e slice
    print(f"starting {num_workers} remote functions")

    result = [tpu_cores.remote(i) for i in range(num_workers)]
    print(ray.get(result))

    print("Done! exiting")
    ray.shutdown()

This should print:

starting 4 remote functions
['Index [0]: Global TPU Count=16', 'Index [1]: Global TPU Count=16', 'Index [2]: Global TPU Count=16', 'Index [3]: Global TPU Count=16']
Done! exiting

Each worker pod runs on one VM with 4 local TPU chips, while jax.device_count() reports all 16 chips of the slice.
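To turn the expected output above into an automatic check, the driver can verify that every task answered and that each one saw the whole slice. A sketch, assuming the exact result format returned by tpu_cores in the example; `check_tpu_results` is a hypothetical helper.

```python
import re

# Matches the strings returned by tpu_cores() in the example above.
_RESULT = re.compile(r"Index \[(\d+)\]: Global TPU Count=(\d+)")


def check_tpu_results(results: list[str], num_workers: int, slice_chips: int) -> bool:
    """Raises ValueError unless every task reported and saw all chips of the slice."""
    seen: dict[int, int] = {}
    for line in results:
        match = _RESULT.fullmatch(line)
        if match is None:
            raise ValueError(f"unexpected result: {line!r}")
        seen[int(match.group(1))] = int(match.group(2))

    missing = sorted(set(range(num_workers)) - set(seen))
    if missing:
        raise ValueError(f"no result from task(s) {missing}")

    # A count other than slice_chips means the task did not see the expected slice.
    partial = {i: n for i, n in seen.items() if n != slice_chips}
    if partial:
        raise ValueError(f"task(s) saw an unexpected chip count: {partial}")
    return True
```

In the driver, it could be called as check_tpu_results(ray.get(result), num_workers, 16) right after the tasks complete.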

Spot instances on GKE

If the quotas are for preemptible TPUs, the experiment must request spot instances. When a preemptible TPU node is reclaimed, the pods are evicted; see Recovering from eviction for how the kuberay operator restarts them automatically.

...
spec:
  operator: "kuberay"
  ...
  types:
    Head: {}
    Workers:
      - name: "tpu-pool"
        count: 1
        resources:
          cpus: 90
          ramRatio: 2

          # node selector to request spot instances on GKE.
          extraSelectors:
            cloud.google.com/gke-spot: "true"
            cloud.google.com/gke-provisioning: spot
          # on GKE, a toleration also has to be passed
          extraTolerations:
            - key: "cloud.google.com/gke-spot"
              operator: "Equal"
              value: "true"
              effect: "NoSchedule"

          accelerators:
            tpu:
              type: "tpu-v5-lite-podslice"
              topology: "2x4"
              tpuChipsCount: 4

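Eviction restarts the pods, so progress made since the last checkpoint is lost unless the script saves state when it is told to stop. One common pattern, sketched below, assumes the pod receives SIGTERM before the spot node is reclaimed (GKE's graceful node shutdown normally gives pods a short grace period for this): record the signal and let the training loop save and exit. PreemptionGuard, train and the save_checkpoint callback are hypothetical names, not AIchor or Ray APIs. Because the grace period is short, the checkpoint write must be fast.

```python
import signal
from typing import Callable


class PreemptionGuard:
    """Hypothetical helper: remembers that SIGTERM was received.

    Assumes the pod gets SIGTERM before the spot TPU node is reclaimed.
    """

    def __init__(self) -> None:
        self.preempted = False

    def install(self) -> None:
        # signal.signal() may only be called from the main thread.
        signal.signal(signal.SIGTERM, self.handle)

    def handle(self, signum, frame) -> None:
        # Keep the handler minimal: set a flag and let the loop react.
        self.preempted = True


def train(num_steps: int, start_step: int, guard: PreemptionGuard,
          save_checkpoint: Callable[[int], None]) -> int:
    """Runs steps [start_step, num_steps); checkpoints and stops early on SIGTERM."""
    step = start_step
    while step < num_steps:
        if guard.preempted:
            save_checkpoint(step)  # must finish within the grace period
            return step
        # ... one training step would run here ...
        step += 1
    save_checkpoint(step)
    return step
```

In the script, guard.install() would be called once from the main thread before training starts; after a restart, the script would read its last checkpoint and pass that step as start_step.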
  • AIchor demo project
  • Google Documentation regarding TPU
  • GCP TPU zones and regions