
Google TPUs on AIchor

Some AI engineers prefer TPUs to GPUs or CPUs for model training, either because of the framework they use or because of the type of model they are training.
In some cases, users also have a partnership or privileged access in GCP to more cost-effective TPU nodes.

In this context, AIchor users can create or import TPU-based engines.

Prerequisites

To use TPU devices on AIchor, the dataplane must be created in a GCP project that has sufficient TPU quotas, in the region or zone where those quotas apply.

⚠️ The quotas must be requested from the "Compute Engine API" service and not from the "Cloud TPU" service.

For example, a correct quota request for TPU on AIchor would be:

  • Service: Compute Engine API
  • Dimensions (e.g. location): region:europe-west4
  • Name: Preemptible TPU v5 Lite Podslice chips
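The quota in this example is expressed in chips. An experiment's footprint can therefore be estimated from its manifest: the number of chips in one slice (the product of the topology dimensions) times the number of slices requested. The sketch below uses hypothetical helper names (`chips_per_slice`, `required_quota_chips`, `fits_quota`). It is not an AIchor or GCP API, and it ignores any other workloads already consuming quota in the project.

```python
# Hypothetical helpers (not part of AIchor or GCP): estimate how many TPU chips
# of quota an experiment consumes, from the topology and count in its manifest.
from math import prod


def chips_per_slice(topology: str) -> int:
    """Number of chips in one slice, e.g. "2x2" -> 4, "4x4" -> 16."""
    return prod(int(dim) for dim in topology.lower().split("x"))


def required_quota_chips(topology: str, count: int = 1) -> int:
    """Chips needed for `count` independent slices of the given topology."""
    if count < 1:
        raise ValueError("count must be at least 1")
    return count * chips_per_slice(topology)


def fits_quota(topology: str, count: int, quota_chips: int) -> bool:
    """True if the requested slices fit within the available chip quota."""
    return required_quota_chips(topology, count) <= quota_chips


if __name__ == "__main__":
    # One 4x4 slice needs 16 chips of quota; two independent 4x4 slices need 32.
    print(required_quota_chips("4x4"))
    print(required_quota_chips("4x4", count=2))
    print(fits_quota("4x4", 2, quota_chips=16))
```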

Learn more about TPU quotas for GKE here.

Creating the dataplane

A TPU cluster can be created from the AIchor UI. Make sure to select the GCP project and region that match the quotas.

Importing the dataplane

AIchor can also import an existing GKE cluster. To use TPUs on it, either node auto-provisioning must be able to provision TPU node pools (ref), or a TPU node pool, autoscaled or not, must be created (ref).

Dependencies and libraries

To use the TPU devices, the container image must include the machine learning library installed with TPU support, together with libtpu.

For example, with JAX:

pip install "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html

Or with PyTorch (the torch and torch_xla versions must match):

pip install torch~=2.5.0 "torch_xla[tpu]~=2.5.0" -f https://storage.googleapis.com/libtpu-releases/index.html
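Once the image is built, a quick way to confirm that libtpu is picked up is to ask JAX which devices it sees. This is a minimal sketch for a single-host slice, assuming the JAX installation above. The helper name `tpu_device_count` is hypothetical, and it returns 0 rather than failing when JAX is absent. On a multi-host slice, JAX expects every host to start the program, so run such checks from all workers at once (as in the kuberay example on this page) rather than on one host alone.

```python
# Sanity check to run inside the container: does JAX see TPU devices?
# Hypothetical helper; assumes JAX was installed with TPU support (jax[tpu]).


def tpu_device_count() -> int:
    """Return how many TPU devices JAX sees on this host (0 if none or JAX is missing)."""
    try:
        import jax
    except ImportError:
        return 0
    # jax.local_devices() lists devices attached to this process; each has a
    # .platform string such as "tpu", "gpu" or "cpu".
    return sum(1 for d in jax.local_devices() if d.platform == "tpu")


if __name__ == "__main__":
    n = tpu_device_count()
    if n == 0:
        print("No TPU devices visible: check that libtpu is installed in the image.")
    else:
        print(f"{n} TPU devices visible")
```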

Start a single-host TPU experiment

A single-host TPU experiment uses a TPU slice made of a single VM. Each TPU version supports different topologies; this documentation lists which topologies are single-host and which are multi-host.

A single-host TPU experiment can be started with any AIchor operator. Example manifest:

...
spec:
  # example with jax. Could be any other AIchor operator
  operator: jax
  ...
  types:
    Worker:
      # Increasing the Worker count requests more single-host TPU slices.
      # As multi-slice is not yet supported, this gives multiple independent single-host TPU slices.
      count: 1
      resources:
        cpus: 90
        ramRatio: 2
        accelerators:
          tpu:
            type: tpu-v5-lite-podslice
            topology: 2x2
            tpuChipsCount: 4

In the example above, a 2x2 TPU slice on TPU v5e has 2x2 = 4 chips. With 4 chips per VM (tpuChipsCount: 4), it requires 4/4 = 1 VM, so it is a single-host slice.
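The same arithmetic applies to any topology: multiply the topology dimensions to get the chips in the slice, then divide by `tpuChipsCount` to get the number of VMs. A result of 1 means single-host. The sketch below encodes only this rule; the helper names are hypothetical, and it does not know which topologies each TPU version actually offers, so check the TPU documentation for valid values.

```python
# Hypothetical helpers: derive the number of VMs in a slice from the manifest's
# topology and tpuChipsCount (chips per VM).
from math import prod


def vms_per_slice(topology: str, tpu_chips_count: int) -> int:
    """VMs in one slice = chips in the topology / chips per VM."""
    chips = prod(int(dim) for dim in topology.lower().split("x"))
    if chips % tpu_chips_count != 0:
        raise ValueError(
            f"{chips} chips cannot be split evenly into VMs of {tpu_chips_count} chips"
        )
    return chips // tpu_chips_count


def is_single_host(topology: str, tpu_chips_count: int) -> bool:
    """A slice is single-host when it fits on exactly one VM."""
    return vms_per_slice(topology, tpu_chips_count) == 1


if __name__ == "__main__":
    print(vms_per_slice("2x2", 4))   # the single-host example above
    print(vms_per_slice("4x4", 4))   # the kuberay example below
    print(is_single_host("2x2", 4))
```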

Start a multi-host TPU experiment with kuberay

Larger TPU slices that span multiple hosts require the kuberay operator; the other operators do not support multi-host TPU slices.
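Because only kuberay supports multi-host slices, a manifest can be checked before submission by combining the VM count with the chosen operator. This is a hypothetical pre-submission check written for illustration; AIchor does not expose a function with this name.

```python
# Hypothetical pre-submission check: multi-host TPU slices must use kuberay,
# single-host slices may use any AIchor operator.
from math import prod

MULTI_HOST_OPERATOR = "kuberay"


def check_tpu_operator(operator: str, topology: str, tpu_chips_count: int) -> None:
    """Raise ValueError if a multi-host slice is paired with an operator other than kuberay."""
    chips = prod(int(dim) for dim in topology.lower().split("x"))
    vms = chips // tpu_chips_count
    if vms > 1 and operator != MULTI_HOST_OPERATOR:
        raise ValueError(
            f"topology {topology} with {tpu_chips_count} chips per VM spans {vms} VMs; "
            f"operator must be '{MULTI_HOST_OPERATOR}', got '{operator}'"
        )


if __name__ == "__main__":
    check_tpu_operator("jax", "2x2", 4)      # single-host: any operator is fine
    check_tpu_operator("kuberay", "4x4", 4)  # multi-host with kuberay: fine
    try:
        check_tpu_operator("jax", "4x4", 4)  # multi-host without kuberay: rejected
    except ValueError as err:
        print(err)
```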

...
spec:
  # must use kuberay
  operator: kuberay
  ...
  types:
    Head:
      resources:
        cpus: 4
        ramRatio: 2

    Workers:
      - name: tpu-pool
        # When requesting a TPU slice, the number of VMs is implied by the given topology.
        # For example, this manifest requests one 4x4 TPU slice with 4 chips per VM: 4x4/4 = 4 VMs.
        # If count is 2, it requests two independent 4x4 TPU slices.
        # Multi-slice is not yet supported.
        count: 1
        resources:
          cpus: 90
          ramRatio: 2
          accelerators:
            tpu:
              type: tpu-v5-lite-podslice
              topology: 4x4
              tpuChipsCount: 4
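Following the comments in the manifest, `count` multiplies whole slices rather than VMs. The helper below (`worker_footprint` and `TpuFootprint` are illustrative, hypothetical names) computes what one `Workers` entry requests. With `count: 2` on the 4x4 topology it reports two slices, 8 VMs and 32 chips in total; since multi-slice is not yet supported, those slices run independently of each other.

```python
# Hypothetical helper: total resources requested by one Workers entry,
# where `count` is the number of independent slices of the given topology.
from dataclasses import dataclass
from math import prod


@dataclass(frozen=True)
class TpuFootprint:
    slices: int          # independent slices (multi-slice is not supported)
    vms_per_slice: int   # chips per slice / tpuChipsCount
    total_vms: int
    total_chips: int


def worker_footprint(count: int, topology: str, tpu_chips_count: int) -> TpuFootprint:
    chips_per_slice = prod(int(dim) for dim in topology.lower().split("x"))
    vms = chips_per_slice // tpu_chips_count
    return TpuFootprint(
        slices=count,
        vms_per_slice=vms,
        total_vms=count * vms,
        total_chips=count * chips_per_slice,
    )


if __name__ == "__main__":
    print(worker_footprint(1, "4x4", 4))  # the manifest above
    print(worker_footprint(2, "4x4", 4))  # same entry with count: 2
```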

Example code using the 4x4 v5e TPU slice requested by the kuberay manifest above:

import os

import jax
import ray


# Each task reserves the 4 TPU chips of one VM, so the 4 tasks are spread
# over the 4 VMs of the 4x4 slice.
@ray.remote(resources={"TPU": 4})
def tpu_cores(index: int):
    # jax.device_count() returns the global number of TPU chips in the slice.
    device_count = jax.device_count()
    return f"Index [{index}]: Global TPU Count={device_count}"


if __name__ == "__main__":
    ray.init(address=os.environ.get("RAY_ADDRESS", "auto"), log_to_driver=True)

    num_workers = 4  # number of VMs (pods) in a 4x4 topology on v5e
    print(f"starting {num_workers} remote functions")

    result = [tpu_cores.remote(i) for i in range(num_workers)]
    print(ray.get(result))

    print("Done! exiting")
    ray.shutdown()

This should print:

starting 4 remote functions
['Index [0]: Global TPU Count=16', 'Index [1]: Global TPU Count=16', 'Index [2]: Global TPU Count=16', 'Index [3]: Global TPU Count=16']
Done! exiting

Each worker (one per VM) has 4 local TPU chips and sees all 16 chips of the slice globally.
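To see the local/global distinction directly, the remote function can report JAX's per-process and global counts. `jax.process_index`, `jax.process_count`, `jax.local_device_count` and `jax.device_count` are standard JAX functions; the wrapper `describe_tpu_view` is a hypothetical name. Called inside `tpu_cores` on the 4x4 v5e slice, each task would be expected to report 4 local devices, 16 global devices and 4 processes. Outside a TPU slice it simply reports whatever devices JAX finds, and it returns an empty dict if JAX is not installed.

```python
# Hypothetical helper: summarize what this process sees, local chips versus
# chips across the whole slice. Intended to be called from inside a Ray task.


def describe_tpu_view() -> dict:
    try:
        import jax
    except ImportError:
        return {}
    return {
        "process_index": jax.process_index(),      # which host/process this is
        "process_count": jax.process_count(),      # number of processes in the slice
        "local_devices": jax.local_device_count(), # chips attached to this VM
        "global_devices": jax.device_count(),      # chips in the whole slice
    }


if __name__ == "__main__":
    print(describe_tpu_view())
```

For example, the body of `tpu_cores` could return `{"index": index, **describe_tpu_view()}` instead of a formatted string.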

Spot instances on GKE

If the quotas are for preemptible TPUs, the experiment must request spot instances:

...
spec:
  operator: kuberay
  ...
  types:
    Head: {}
    Workers:
      - name: tpu-pool
        count: 1
        resources:
          cpus: 90
          ramRatio: 2

          # node selectors to request spot instances on GKE
          extraSelectors:
            cloud.google.com/gke-spot: "true"
            cloud.google.com/gke-provisioning: spot
          # on GKE, a matching toleration also has to be passed
          extraTolerations:
            - key: "cloud.google.com/gke-spot"
              operator: "Equal"
              value: "true"
              effect: "NoSchedule"

          accelerators:
            tpu:
              type: tpu-v5-lite-podslice
              topology: 2x4
              tpuChipsCount: 4
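When manifests are generated programmatically, the spot settings can be added to a worker entry with a small helper. This sketch is hypothetical: it assumes the manifest is handled as a plain Python dict before being written to YAML, and that `extraSelectors` and `extraTolerations` sit under `resources`, next to `accelerators`. Only the labels, toleration key and values are taken from this page.

```python
# Hypothetical helper: add the GKE spot node selectors and toleration shown on
# this page to a kuberay Workers entry represented as a dict.
import copy

SPOT_SELECTORS = {
    "cloud.google.com/gke-spot": "true",
    "cloud.google.com/gke-provisioning": "spot",
}
SPOT_TOLERATION = {
    "key": "cloud.google.com/gke-spot",
    "operator": "Equal",
    "value": "true",
    "effect": "NoSchedule",
}


def with_spot_scheduling(worker: dict) -> dict:
    """Return a copy of `worker` with spot selectors and toleration under `resources` (assumed layout)."""
    updated = copy.deepcopy(worker)
    resources = updated.setdefault("resources", {})
    resources.setdefault("extraSelectors", {}).update(SPOT_SELECTORS)
    tolerations = resources.setdefault("extraTolerations", [])
    if SPOT_TOLERATION not in tolerations:
        tolerations.append(dict(SPOT_TOLERATION))
    return updated


if __name__ == "__main__":
    worker = {
        "name": "tpu-pool",
        "count": 1,
        "resources": {
            "cpus": 90,
            "ramRatio": 2,
            "accelerators": {
                "tpu": {"type": "tpu-v5-lite-podslice", "topology": "2x4", "tpuChipsCount": 4}
            },
        },
    }
    print(with_spot_scheduling(worker))
```

The helper leaves the input untouched and does not duplicate the toleration if it is applied twice.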