Google TPUs on AIchor
Some AI engineers prefer TPUs over GPUs or CPUs for model training, either because of the framework they use or because of the type of model being trained.
In some cases, users also have a partnership or privileged access in GCP that gives them more cost-effective TPU nodes.
For these users, AIchor makes it possible to import or create TPU-based engines.
Prerequisites
To use TPU devices on AIchor, the dataplane must be created in a GCP project that has sufficient TPU quotas, in the region or zone those quotas apply to.
⚠️ The quotas must be requested from the "Compute Engine API" service and not from the "Cloud TPU" service.
For example, a correct quota request for TPU on AIchor would be:
- Service: Compute Engine API
- Dimensions (e.g. location): region:europe-west4
- Name: Preemptible TPU v5 Lite Podslice chips
Learn more about TPU quotas for GKE here.
Creating the dataplane
A TPU cluster can be created from the AIchor UI. Make sure to select the GCP project and region that match the quotas.
Importing the dataplane
AIchor can also import an existing GKE cluster. To use TPUs on it, either node auto-provisioning must be able to provision TPU node pools (ref), or a TPU node pool, autoscaled or not, must be created on the cluster (ref).
Dependencies and libraries
To use the TPU devices, the container image must include the machine learning library together with libtpu.
For example, with JAX:
```shell
pip install "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
```
Or with PyTorch/XLA:
```shell
pip install torch "torch_xla[tpu]~=2.5.0" -f https://storage.googleapis.com/libtpu-releases/index.html
```
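To confirm that the image was built correctly, a quick smoke check can be run inside the container on a single-host slice. This is a minimal sketch assuming JAX is the library in use; tpu_visible_to_jax is a hypothetical helper name. It relies only on jax.default_backend(), which returns "tpu" when the TPU backend is active, and reports False when JAX is not installed at all. On a multi-host slice the same program has to run on every host at once (as in the kuberay example further down), so this standalone check is not meant for that case.

```python
import importlib.util


def tpu_visible_to_jax() -> bool:
    """Return True if JAX is installed and its default backend is TPU (hypothetical helper)."""
    # JAX may simply be missing from the image.
    if importlib.util.find_spec("jax") is None:
        return False
    import jax

    # "tpu" means libtpu is installed and TPU chips are attached to this VM;
    # otherwise JAX falls back to another backend such as "cpu".
    return jax.default_backend() == "tpu"


if __name__ == "__main__":
    if tpu_visible_to_jax():
        import jax

        print(f"TPU OK: {jax.local_device_count()} local TPU chips")
    else:
        print("No TPU backend: check that libtpu is installed in the image")
```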
Start a single-host TPU experiment
A single-host TPU experiment runs on a TPU slice backed by a single VM. Each TPU version has its own set of topologies; this documentation shows whether a given topology is single-host or multi-host.
A single-host TPU experiment can be started with any AIchor operator. Example manifest excerpt:
```yaml
...
spec:
  # example with jax. Could be any other AIchor operator
  operator: jax
  ...
  types:
    Worker:
      # increasing the Worker count requests more single-host TPU slices.
      # As multi-slice is not yet supported, this gives multiple independent single-host TPU slices.
      count: 1
      resources:
        cpus: 90
        ramRatio: 2
        accelerators:
          tpu:
            type: tpu-v5-lite-podslice
            topology: 2x2
            tpuChipsCount: 4
```
In the example above, the 2x2 TPU v5e slice has 2 × 2 = 4 chips; with 4 chips per VM (tpuChipsCount), it requires 4 / 4 = 1 VM.
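The same arithmetic decides whether a topology is single-host or multi-host. The sketch below assumes every VM in the slice holds tpuChipsCount chips; vms_per_slice is a hypothetical helper, not part of AIchor.

```python
from math import prod


def vms_per_slice(topology: str, chips_per_vm: int) -> int:
    """Number of VMs (hosts) backing one TPU slice.

    topology: the manifest's topology value, e.g. "2x2" or "4x4".
    chips_per_vm: the manifest's tpuChipsCount value.
    """
    total_chips = prod(int(dim) for dim in topology.lower().split("x"))
    if total_chips % chips_per_vm != 0:
        raise ValueError(
            f"{total_chips} chips cannot be split into VMs of {chips_per_vm} chips"
        )
    return total_chips // chips_per_vm


# Topologies used in the manifests on this page, all with 4 chips per VM.
for topology in ("2x2", "2x4", "4x4"):
    vms = vms_per_slice(topology, chips_per_vm=4)
    kind = "single-host" if vms == 1 else "multi-host"
    print(f"{topology}: {vms} VM(s), {kind}")
```

Any result greater than 1 means the slice is multi-host and therefore needs the kuberay operator.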
Start a multi-host TPU experiment with kuberay
Larger TPU slices span multiple hosts (VMs). Multi-host TPU slices require the kuberay operator; other operators do not support them.
```yaml
...
spec:
  # must use kuberay
  operator: kuberay
  ...
  types:
    Head:
      resources:
        cpus: 4
        ramRatio: 2
    Workers:
      - name: tpu-pool
        # When requesting a TPU slice, the number of VMs is implied by the topology.
        # For example, this manifest requests one 4x4 TPU slice with 4 chips per VM: 4x4/4 = 4 VMs.
        # If count is 2, it requests two independent 4x4 TPU slices.
        # Multi-slice is not yet supported.
        count: 1
        resources:
          cpus: 90
          ramRatio: 2
          accelerators:
            tpu:
              type: tpu-v5-lite-podslice
              topology: 4x4
              tpuChipsCount: 4
```
Example code using this 4x4 v5e TPU slice:
```python
import os

import jax
import ray


@ray.remote(resources={"TPU": 4})
def tpu_cores(index: int):
    # Each task claims the 4 TPU chips of one worker pod (one VM of the slice);
    # jax.device_count() returns the global chip count of the whole slice.
    device_count = jax.device_count()
    return f"Index [{index}]: Global TPU Count={device_count}"


if __name__ == "__main__":
    ray.init(address=os.environ.get("RAY_ADDRESS", "auto"), log_to_driver=True)
    num_workers = 4  # number of pods in a 4x4 topology on v5e
    print(f"starting {num_workers} remote functions")
    result = [tpu_cores.remote(i) for i in range(num_workers)]
    print(ray.get(result))
    print("Done! exiting")
    ray.shutdown()
```
This should print:
```text
starting 4 remote functions
['Index [0]: Global TPU Count=16', 'Index [1]: Global TPU Count=16', 'Index [2]: Global TPU Count=16', 'Index [3]: Global TPU Count=16']
Done! exiting
```
Each worker has 4 local TPU chips, and JAX sees all 16 chips of the slice globally.
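Since only kuberay supports multi-host slices, a manifest can be sanity-checked locally before it is submitted. The sketch below is a hypothetical helper, not part of AIchor: it assumes the manifest's spec section has been parsed into a Python dict (for example with PyYAML) and only inspects the Worker and Workers entries used in the examples on this page.

```python
from math import prod


def multi_host_problems(spec: dict) -> list:
    """List TPU worker groups that need more than one VM but do not use kuberay.

    Hypothetical helper: spec is the manifest's spec section parsed into a dict.
    """
    problems = []
    types = spec.get("types", {})
    # kuberay declares a list under "Workers"; the single-host example uses one "Worker".
    groups = list(types.get("Workers", []))
    if "Worker" in types:
        groups.append(types["Worker"])
    for group in groups:
        tpu = group.get("resources", {}).get("accelerators", {}).get("tpu")
        if tpu is None:
            continue  # this group requests no TPU
        chips = prod(int(dim) for dim in tpu["topology"].lower().split("x"))
        vms = chips // tpu["tpuChipsCount"]
        if vms > 1 and spec.get("operator") != "kuberay":
            problems.append(
                f"{tpu['topology']} slice spans {vms} VMs: operator must be kuberay"
            )
    return problems


tpu_4x4 = {"type": "tpu-v5-lite-podslice", "topology": "4x4", "tpuChipsCount": 4}
worker = {"count": 1, "resources": {"cpus": 90, "ramRatio": 2, "accelerators": {"tpu": tpu_4x4}}}

# A 4x4 slice with the jax operator is rejected; the same slice with kuberay passes.
print(multi_host_problems({"operator": "jax", "types": {"Worker": worker}}))
print(multi_host_problems({"operator": "kuberay", "types": {"Workers": [dict(worker, name="tpu-pool")]}}))
```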
Spot instances on GKE
If your quotas are for preemptible TPUs, the experiment must request spot instances:
```yaml
...
spec:
  operator: kuberay
  ...
  types:
    Head: {}
    Workers:
      - name: tpu-pool
        count: 1
        resources:
          cpus: 90
          ramRatio: 2
          # node selectors to request spot instances on GKE
          extraSelectors:
            cloud.google.com/gke-spot: "true"
            cloud.google.com/gke-provisioning: spot
          # on GKE, a matching toleration also has to be passed
          extraTolerations:
            - key: "cloud.google.com/gke-spot"
              operator: "Equal"
              value: "true"
              effect: "NoSchedule"
          accelerators:
            tpu:
              type: tpu-v5-lite-podslice
              topology: 2x4
              tpuChipsCount: 4
```
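When manifests are generated from a script, the spot selectors and the toleration can be added together so that one is never set without the other. This is a minimal sketch assuming the worker's resources block has been parsed into a dict, with extraSelectors and extraTolerations at the same level as cpus and ramRatio as in the manifest above; with_spot is a hypothetical helper, not part of AIchor.

```python
import copy

# Values taken from the spot manifest above.
SPOT_SELECTORS = {
    "cloud.google.com/gke-spot": "true",
    "cloud.google.com/gke-provisioning": "spot",
}
SPOT_TOLERATION = {
    "key": "cloud.google.com/gke-spot",
    "operator": "Equal",
    "value": "true",
    "effect": "NoSchedule",
}


def with_spot(resources: dict) -> dict:
    """Return a copy of a worker's resources block that targets GKE spot nodes.

    Hypothetical helper: adds both the node selectors and the matching toleration.
    """
    patched = copy.deepcopy(resources)  # never mutate the caller's dict
    patched.setdefault("extraSelectors", {}).update(SPOT_SELECTORS)
    tolerations = patched.setdefault("extraTolerations", [])
    if SPOT_TOLERATION not in tolerations:  # keep the helper idempotent
        tolerations.append(dict(SPOT_TOLERATION))
    return patched


resources = {
    "cpus": 90,
    "ramRatio": 2,
    "accelerators": {"tpu": {"type": "tpu-v5-lite-podslice", "topology": "2x4", "tpuChipsCount": 4}},
}
spot = with_spot(resources)
print(spot["extraSelectors"])
print(len(with_spot(spot)["extraTolerations"]))  # applying twice adds no duplicate toleration
```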
Demo project and other documentation links
- AIchor demo project: https://github.com/instadeepai/aichor-demo/tree/main/tpu/kuberay-multi-host
- Google Documentation regarding TPU: https://cloud.google.com/kubernetes-engine/docs/concepts/tpus#type-node-pool
- GCP TPU zones and regions: https://cloud.google.com/tpu/docs/regions-zones#europe