Google TPUs on AIchor
Some AI engineers train models on TPUs instead of GPUs or CPUs, either because of the framework they use or because of the type of model being trained.
In some cases, users have a partnership with or privileged access to GCP that gives them more cost-effective TPU nodes.
For these cases, AIchor admins can import or create a GKE engine with TPUs on GCP.
Prerequisites
To use TPU devices on AIchor, the engine must be created or imported in a GCP project with sufficient TPU quotas, in the correct region or zone.
⚠️ The quotas must be requested from the "Compute Engine API" service and not from the "Cloud TPU" service.
For example, a correct quota request for TPU on AIchor would be:
| Field | Value |
|---|---|
| Service | Compute Engine API |
| Dimensions (e.g. location) | region:europe-west4 |
| Name | Preemptible TPU v5 Lite Podslice chips |
Learn more about TPU quotas for GKE here.
Dependencies and libraries
To use the TPU devices, the machine learning library must be installed together with libtpu in the container image.
For example, with JAX:
pip install "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
Or with PyTorch:
pip install torch "torch_xla[tpu]~=2.5.0" -f https://storage.googleapis.com/libtpu-releases/index.html
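As a quick sanity check after building the image, a sketch like the following can report which devices JAX sees. It assumes the JAX install path above; `describe_accelerators` is a hypothetical helper name, not part of AIchor, and on a machine without TPUs JAX will typically report CPU devices instead.

```python
def describe_accelerators() -> str:
    """Summarise the devices visible to JAX (hypothetical helper, not an AIchor API)."""
    try:
        import jax
    except ImportError:
        return "jax is not installed"
    try:
        devices = jax.devices()
    except RuntimeError as err:
        # Raised when JAX cannot initialise any backend.
        return f"JAX could not initialise a backend: {err}"
    platforms = sorted({device.platform for device in devices})
    return f"{len(devices)} device(s), platform(s): {', '.join(platforms)}"


if __name__ == "__main__":
    print(describe_accelerators())
```

Running this inside the experiment container is one way to confirm that libtpu was picked up before launching a long training job.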
Start a single host TPU experiment
A single host TPU experiment uses a TPU slice made of one VM. Each TPU version has different topology designs; from this documentation you can determine whether a topology is single host or multi host.
A single host TPU experiment can be started using any AIchor operator. Example manifest:
...
spec:
  # example with jax. Could be any other AIchor operator
  operator: "jax"
  ...
  types:
    worker:
      # increasing the number of workers will request more single host TPU slices.
      # As multi-slice is not yet supported, this provides multiple independent single host TPU slices.
      count: 1
      resources:
        cpus: 90
        ramRatio: 2
        accelerators:
          tpu:
            type: "tpu-v5-lite-podslice"
            topology: "2x2"
            tpuChipsCount: 4
In the example above, a 2x2 TPU slice on TPU v5e contains 2 × 2 = 4 chips. With 4 chips per VM, it requests 4 / 4 = 1 VM.
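The same arithmetic can be sketched in a few lines of Python to check whether a given topology is single host or multi host. This is only an illustration: `vms_per_slice` is a hypothetical helper, and it assumes the number of chips per VM equals the manifest's `tpuChipsCount`.

```python
from math import prod


def vms_per_slice(topology: str, chips_per_vm: int) -> int:
    """Return how many VMs one TPU slice spans (hypothetical helper)."""
    # "4x4" -> [4, 4]; the product of the dimensions is the chip count.
    dims = [int(dim) for dim in topology.lower().split("x")]
    total_chips = prod(dims)
    if total_chips % chips_per_vm != 0:
        raise ValueError(
            f"{topology} has {total_chips} chips, not divisible by {chips_per_vm}"
        )
    return total_chips // chips_per_vm


if __name__ == "__main__":
    for topology in ("2x2", "2x4", "4x4"):
        vms = vms_per_slice(topology, chips_per_vm=4)
        kind = "single host" if vms == 1 else "multi host"
        print(f"{topology}: {vms} VM(s), {kind}")
```

A result of 1 means the slice fits on one VM and any operator can be used; anything larger is a multi host slice.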
Start a multi host TPU experiment with kuberay
Bigger TPU slices that span multiple hosts require the kuberay operator; the other operators do not support multi host TPU slices.
...
spec:
  operator: "kuberay"
  ...
  types:
    Head:
      resources:
        cpus: 4
        ramRatio: 2
    Workers:
      - name: "tpu-pool"
        # When requesting a TPU slice, the number of VMs is implied by the given topology.
        # For example, this manifest requests one 4x4 TPU slice with 4 chips per VM: 4x4/4 = 4 VMs.
        # If count is 2, then it will request two 4x4 TPU slices.
        # Multi-slice is not yet supported.
        count: 1
        resources:
          cpus: 90
          ramRatio: 2
          accelerators:
            tpu:
              type: "tpu-v5-lite-podslice"
              topology: "4x4"
              tpuChipsCount: 4
Example of code using this 4x4 v5e TPU slice:
import ray
import jax
import os


@ray.remote(resources={"TPU": 4})
def tpu_cores(index: int):
    device_count = jax.device_count()
    return f"Index [{index}]: Global TPU Count={device_count}"


if __name__ == "__main__":
    ray.init(address=os.environ.get("RAY_ADDRESS", "auto"), log_to_driver=True)
    num_workers = 4  # number of pods in 4x4 topology on v5e
    print(f"starting {num_workers} remote functions")
    result = [tpu_cores.remote(i) for i in range(num_workers)]
    print(ray.get(result))
    print("Done! exiting")
    ray.shutdown()
This should print:
starting 4 remote functions
['Index [0]: Global TPU Count=16', 'Index [1]: Global TPU Count=16', 'Index [2]: Global TPU Count=16', 'Index [3]: Global TPU Count=16']
Done! exiting
Each worker has 4 local TPU chips and can access the 16 chips globally.
Spot instances on GKE
If the quotas are for preemptible TPU, the experiment must request spot instances. When a preemptible TPU node is reclaimed, the pods are evicted; see Recovering from eviction for how the kuberay operator restarts them automatically.
...
spec:
  operator: "kuberay"
  ...
  types:
    Head: {}
    Workers:
      - name: "tpu-pool"
        count: 1
        resources:
          cpus: 90
          ramRatio: 2
          # node selector to request spot instances on GKE.
          extraSelectors:
            cloud.google.com/gke-spot: "true"
            cloud.google.com/gke-provisioning: spot
          # on GKE, a toleration also has to be passed
          extraTolerations:
            - key: "cloud.google.com/gke-spot"
              operator: "Equal"
              value: "true"
              effect: "NoSchedule"
          accelerators:
            tpu:
              type: "tpu-v5-lite-podslice"
              topology: "2x4"
              tpuChipsCount: 4
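Because the selector and the toleration have to be set together, it can help to check a manifest's `resources` block before submitting it. The sketch below is one possible check, not an AIchor feature: `spot_scheduling_fields` and `requests_spot` are hypothetical helpers, and it assumes `extraSelectors` and `extraTolerations` sit under `resources` as in the manifest above.

```python
SPOT_KEY = "cloud.google.com/gke-spot"


def spot_scheduling_fields() -> dict:
    """Return the selector and toleration fields used in the manifest above."""
    return {
        "extraSelectors": {
            SPOT_KEY: "true",
            "cloud.google.com/gke-provisioning": "spot",
        },
        "extraTolerations": [
            {
                "key": SPOT_KEY,
                "operator": "Equal",
                "value": "true",
                "effect": "NoSchedule",
            }
        ],
    }


def requests_spot(resources: dict) -> bool:
    """Check that both the spot selector and its matching toleration are set."""
    selectors = resources.get("extraSelectors", {})
    tolerations = resources.get("extraTolerations", [])
    has_selector = selectors.get(SPOT_KEY) == "true"
    has_toleration = any(
        t.get("key") == SPOT_KEY
        and t.get("value") == "true"
        and t.get("effect") == "NoSchedule"
        for t in tolerations
    )
    return has_selector and has_toleration


if __name__ == "__main__":
    resources = {"cpus": 90, "ramRatio": 2, **spot_scheduling_fields()}
    print("spot requested:", requests_spot(resources))
```

The dictionaries mirror the YAML keys one to one, so the same check could be applied to a manifest loaded with any YAML parser.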