How to split a large model on Colab TPUs

Published

February 22, 2023

Google Colab has been an amazing tool for trying out new ML ideas. Recently, I’ve been checking out Colab’s TPU accelerator offering with the JAX platform and found out some really interesting stuff.

Colab offers 64 GiB of high-bandwidth memory with its TPUv2 runtime, a considerable jump in memory compared to Colab’s GPU instances. (More details about the different TPU versions are available at Google Cloud.) Let’s see if we can run the full Stable Diffusion model by Runway with the HuggingFace backbone using the provided JAX + TPU support.

Getting started

HuggingFace offers a convenient starting point for JAX + TPU:

HuggingFace starter code
import jax
import jax.numpy as jnp
import numpy as np
from PIL import Image
from flax.jax_utils import replicate
from flax.training.common_utils import shard

from diffusers import FlaxStableDiffusionPipeline

dtype = jnp.bfloat16

pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", revision="bf16", dtype=dtype, from_pt=True
)
prompt = "A cinematic film still of Morgan Freeman starring as Jimi Hendrix, portrait, 40mm lens, shallow depth of field, close up, split lighting, cinematic"
prompt = [prompt] * jax.device_count()
prompt_ids = pipeline.prepare_inputs(prompt)

p_params = replicate(params, devices=jax.devices()) # replicate adds a leading device axis (len(devices), ...) to each tensor
prompt_ids = shard(prompt_ids)

def create_key(seed=0):
    return jax.random.PRNGKey(seed)

rng = create_key(0)
rng = jax.random.split(rng, jax.device_count())

images = pipeline(prompt_ids, p_params, rng, jit=True)[0]
images = images.reshape((images.shape[0] * images.shape[1], ) + images.shape[-3:])
images = pipeline.numpy_to_pil(images)

def image_grid(imgs, rows, cols):
    w,h = imgs[0].size
    grid = Image.new('RGB', size=(cols*w, rows*h))
    for i, img in enumerate(imgs): grid.paste(img, box=(i%cols*w, i//cols*h))
    return grid
  
image_grid(images, 2, 4)

It uses the replicate and shard utilities of Flax to make 8 copies of the model (one for each TPU device) and parallelize 8 instances of the forward pass. This fits into memory on the smaller, bfloat16 branch where the pipeline is initialized with revision="bf16" and when the safety checker is turned off.
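To make the shape bookkeeping concrete, here is a small pure-Python sketch (no TPU required) of what happens to the leading axes. It assumes that shard reshapes the batch axis into (device_count, batch // device_count) and that the final reshape in the starter folds those two axes back together. The helper names and the example shapes are made up for illustration.

```python
def shard_shape(shape, n_devices):
    """Shape after splitting the leading batch axis across n_devices."""
    batch, rest = shape[0], tuple(shape[1:])
    if batch % n_devices != 0:
        raise ValueError("batch size must be divisible by the device count")
    return (n_devices, batch // n_devices) + rest


def merge_device_axis(shape):
    """Shape after folding (devices, per_device_batch, ...) into one batch axis."""
    return (shape[0] * shape[1],) + tuple(shape[2:])


# Illustrative numbers: 8 tokenized prompts of 77 tokens on 8 devices.
print(shard_shape((8, 77), n_devices=8))  # (8, 1, 77)

# Illustrative output: one 512x512 RGB image per device.
print(merge_device_axis((8, 1, 512, 512, 3)))  # (8, 512, 512, 3)
```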

However, as written (with the safety checker still enabled), the starter code runs out of TPU memory with either revision="bf16" or revision="flax":

XlaRuntimeError: RESOURCE_EXHAUSTED: Could not allocate 12582912 bytes in memory 0x0x0_HBM0; 12517376 bytes allocatable, 18661376 bytes available

The only way to get results is to reduce the total allocation size by disabling the safety checker:

Output after setting safety_checker=None
You have passed `None` for safety_checker to disable its functionality in <class 'diffusers.pipelines.stable_diffusion.pipeline_flax_stable_diffusion.FlaxStableDiffusionPipeline'>. Note that this might lead to problems when using <class 'diffusers.pipelines.stable_diffusion.pipeline_flax_stable_diffusion.FlaxStableDiffusionPipeline'> and is not recommended.
Some weights of the model checkpoint at stable-diffusion-v1-5/text_encoder were not used when initializing FlaxCLIPTextModel: {('text_model', 'embeddings', 'position_ids')}
- This IS expected if you are initializing FlaxCLIPTextModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing FlaxCLIPTextModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
You have disabled the safety checker for <class 'diffusers.pipelines.stable_diffusion.pipeline_flax_stable_diffusion.FlaxStableDiffusionPipeline'> by passing `safety_checker=None`. Ensure that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered results in services or applications open to the public. Both the diffusers team and Hugging Face strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling it only for use-cases that involve analyzing network behavior or auditing its results. For more information, please have a look at https://github.com/huggingface/diffusers/pull/254 .

Stable diffusion output

Since the safety checker runs separately from the U-Net, we could load the safety checker on the next TPU core with a small modification in the pipeline. This solves the problem for our specific case but is not generalizable for models that span more than one TPU device. Can we figure out a generalizable way of spreading tensors over multiple TPU devices and still support matrix operations across device borders?

To motivate this venture, we should keep in mind that Colab is offering us a total of 64 GiB of high-bandwidth memory spread over 8 devices (8 cores on 4 chips). The ability to split a model across multiple devices without making assumptions about the underlying deep learning architecture can be a nifty way to play around with large models with minimal effort.
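As a rough back-of-the-envelope check, the sketch below divides the 64 GiB evenly over the 8 devices and tests whether a given set of weights fits on one device or only when spread over several. The even split and the parameter count are assumptions for illustration, not measurements, and activations and compiler scratch space are ignored.

```python
GIB = 1024 ** 3

TOTAL_HBM_BYTES = 64 * GIB                        # figure quoted above
N_DEVICES = 8
PER_DEVICE_BYTES = TOTAL_HBM_BYTES // N_DEVICES   # assumes an even split


def weight_bytes(n_params, bytes_per_param):
    """Memory taken by the weights alone (no activations, no scratch space)."""
    return n_params * bytes_per_param


# Hypothetical model: 6 billion parameters stored as bfloat16 (2 bytes each).
needed = weight_bytes(6_000_000_000, 2)

print(PER_DEVICE_BYTES / GIB)        # 8.0
print(needed <= PER_DEVICE_BYTES)    # False: too big for a single device
print(needed <= TOTAL_HBM_BYTES)     # True: fits only if spread over devices
```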

JAX support for cross-device programming

JAX tells us that if two tensors live on different devices A and B, multiplying them fails.

key = jax.random.PRNGKey(0)
a = jax.random.uniform(key, shape=(3,5))
b = jax.random.uniform(key, shape=(5,3))
a = jax.device_put(a, device=jax.devices()[0])
b = jax.device_put(b, device=jax.devices()[1])
a @ b
ValueError: primitive arguments must be colocated on the same device, got TPU_0(host=0,(0,0,0,0)), TPU_1(host=0,(0,0,0,1))

The most performance-conscious way to spread a matrix multiplication over two devices is to use JAX’s pjit library, following its detailed official walkthrough. pjit, as of writing this post in February 2023, is an experimental library that optimizes your code with just-in-time compilation and works with a device mesh definition. This device mesh lets the user spread selected steps of the computation graph across device slices with APIs like jax.lax.with_sharding_constraint. An even better tutorial uses clear diagrams to show how a matrix multiplication will be sharded under the constraints you impose. The paper mentioned in that tutorial even provides an optimal sharding spec for its specific architecture.

Unfortunately, Colab seems to run a legacy version of TPUv2 which does not have pjit support. The aforementioned walkthrough also warns about this at the top. Moreover, producing a PartitionSpec and adding jax.lax.with_sharding_constraint to individual steps of the computation graph has to be done manually and correctly for each deep learning architecture, which wouldn’t make our solution generalizable. Fortunately, a bit of digging into the JAX API turns up one more thing to try.

Although jax.Array is the default return type of most JAX operations, jax.device_put_sharded and jax.device_put_replicated return a special subtype called ShardedDeviceArray, and this particular type allows operations between devices. Let’s revisit our previous matrix multiplication a @ b, this time using jax.device_put_sharded:

key = jax.random.PRNGKey(0)
a = jax.random.uniform(key, shape=(3,5))
b = jax.random.uniform(key, shape=(5,3))
a = jax.device_put_sharded([a], jax.devices()[0:1])
b = jax.device_put_sharded([b], jax.devices()[1:2])
a[0] @ b[0]
DeviceArray([[2.1039276 , 1.9485016 , 1.3808594 ],
             [1.0550652 , 0.8924408 , 0.4976406 ],
             [1.4164448 , 1.3743296 , 0.80164933]], dtype=float32)

So converting everything to this data type lets us perform operations across devices. Neat!

One thing to note is that jax.device_put_sharded shards along the first axis, so we wrapped each array in a one-element list, which adds a leading (1, ...) to its shape and puts the whole tensor on a single device. If needed, we can index out this extra dimension with [0], and we can just as easily convert all tensors in a parameter pytree params:

jax.tree_util.tree_map(
    lambda x: jax.device_put_sharded([x], devices=[random.choice(jax.devices())]),
    params)

Don’t worry, random.choice is there for dramatic effect (though it may work for models that do not push the TPU memory limit). Instead of choosing the allocation device at random, it makes more intuitive sense to place related tensors physically close to each other, using a clock algorithm that moves the target device pointer to a neighboring device when the current device gets close to capacity.
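Here is a device-free sketch of that idea, assuming we know each tensor’s size in bytes and a per-device byte budget (both are hypothetical inputs here, and plan_allocation is a made-up helper name). It only decides where each tensor would go and does not move any data:

```python
def plan_allocation(tensor_sizes, capacities):
    """Assign each tensor (given by its size in bytes) to a device index.

    Stays on the current device while the tensor fits; otherwise advances
    the pointer to the next device with enough room, wrapping around once.
    """
    free = list(capacities)   # remaining bytes per device
    current = 0               # the "clock hand"
    plan = []
    for size in tensor_sizes:
        for step in range(len(free)):
            candidate = (current + step) % len(free)
            if free[candidate] >= size:
                current = candidate
                free[current] -= size
                plan.append(current)
                break
        else:
            # No device had enough room for this tensor.
            raise MemoryError(f"cannot place {size} bytes; free: {free}")
    return plan


# Toy example: three devices with a budget of 10 bytes each.
print(plan_allocation([6, 3, 5, 4, 7], [10, 10, 10]))  # [0, 0, 1, 1, 2]
```

Because consecutive tensors in the pytree tend to belong to the same layer, staying on the current device until it fills up keeps related tensors together.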

Putting it all together

A small memory fix is needed on the pipeline side to make it work: while reading the PyTorch model off the disk and converting it into JAX, there’s a brief window where PyTorch and JAX weights are both referenced in RAM. This duplication of weights exhausts CPU memory before we can move things to the TPU. The quick solution is to keep overwriting the same state variable name during conversion so that only one copy of the weights is referenced. The garbage collector does the rest of the heavy lifting:

Around line 406 of modeling_flax_utils.py, change

    # Step 1: Get the pytorch file
    pytorch_model_file = load_state_dict(model_file)
    
    # Step 2: Convert the weights
    state = convert_pytorch_state_dict_to_flax(pytorch_model_file, model)

with 

    # Step 1: Get the pytorch file
    state = load_state_dict(model_file)
    
    # Step 2: Convert the weights
    state = convert_pytorch_state_dict_to_flax(state, model)
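The effect of reusing the variable name can be seen in a toy example. This is only a sketch with stand-in objects (FakeWeights and convert are hypothetical and not part of diffusers), and it assumes nothing else holds a reference to the original weights:

```python
import gc
import weakref


class FakeWeights:
    """Stand-in for a large state dict (weak references need a class instance)."""

    def __init__(self, framework):
        self.framework = framework


def convert(weights):
    """Stand-in for the PyTorch-to-Flax conversion."""
    return FakeWeights("flax")


state = FakeWeights("pytorch")
old = weakref.ref(state)   # lets us observe the object without keeping it alive

state = convert(state)     # rebinding drops the last reference to the old weights
gc.collect()

print(old() is None)       # True: the PyTorch-side object has been freed
print(state.framework)     # flax
```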

For personal convenience, I cloned the weight repo (runwayml/stable-diffusion-v1-5) to my local relative Google Drive path so my pipeline calls read FlaxStableDiffusionPipeline.from_pretrained("stable-diffusion-v1-5", ...) instead.

Most importantly, let’s remove the model replication from the HuggingFace starter. After all, the goal is to get a large model to generate a single output, as opposed to getting a small model to produce multiple outputs.

With all this in mind, here’s the error case that puts everything on one TPU device:

Code and output
# single image, non-jitted and everything on single TPU with safety checker

import sys

import jax
import jax.numpy as jnp
from PIL import Image

# workaround for pip install -e not working:
if "/content/drive/MyDrive/projects/diffusers/src" not in sys.path:
  sys.path.append("/content/drive/MyDrive/projects/diffusers/src")
from diffusers import FlaxStableDiffusionPipeline

dtype = jnp.bfloat16

pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
    "stable-diffusion-v1-5", revision="flax", dtype=dtype, from_pt=True
)
print("pipeline loaded")
prompt = "A cinematic film still of Morgan Freeman starring as Jimi Hendrix, portrait, 40mm lens, shallow depth of field, close up, split lighting, cinematic"
prompt = [prompt] * 1
prompt_ids = pipeline.prepare_inputs(prompt)

p_params = jax.tree_util.tree_map(lambda x: jax.device_put(x, jax.devices()[2]), params)

def create_key(seed=0):
    return jax.random.PRNGKey(seed)
rng = create_key(0)
rng = jax.random.split(rng, 1)

images = pipeline(prompt_ids, p_params, rng, jit=False)[0]
images = pipeline.numpy_to_pil(images)

def image_grid(imgs, rows, cols):
    w,h = imgs[0].size
    grid = Image.new('RGB', size=(cols*w, rows*h))
    for i, img in enumerate(imgs): grid.paste(img, box=(i%cols*w, i//cols*h))
    return grid

image_grid(images, 1, 1)

Output:

Some weights of the model checkpoint at stable-diffusion-v1-5/text_encoder were not used when initializing FlaxCLIPTextModel: {('text_model', 'embeddings', 'position_ids')}
- This IS expected if you are initializing FlaxCLIPTextModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing FlaxCLIPTextModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of the model checkpoint at stable-diffusion-v1-5/safety_checker were not used when initializing FlaxStableDiffusionSafetyChecker: {('vision_model', 'vision_model', 'embeddings', 'position_ids')}
- This IS expected if you are initializing FlaxStableDiffusionSafetyChecker from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing FlaxStableDiffusionSafetyChecker from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
pipeline loaded
UnfilteredStackTrace: jaxlib.xla_extension.XlaRuntimeError: FAILED_PRECONDITION: Dependency failed: Dependency 
failed: Timed out while waiting for dependency 671371:66613 to be resolved. This is usually due to a server restart
and a stale client. Try restarting your client.

If we keep everything the same but distribute the model instead, we are in good shape:

Code and output
# single image, non-jitted, greedily sharded with safety checker

import sys

import jax
import jax.numpy as jnp
from PIL import Image

# workaround for pip install -e not working:
if "/content/drive/MyDrive/projects/diffusers/src" not in sys.path:
  sys.path.append("/content/drive/MyDrive/projects/diffusers/src")
from diffusers import FlaxStableDiffusionPipeline

dtype = jnp.bfloat16

pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
    "stable-diffusion-v1-5", revision="flax", dtype=dtype, from_pt=True,
)
prompt = "A cinematic film still of Morgan Freeman starring as Jimi Hendrix, portrait, 40mm lens, shallow depth of field, close up, split lighting, cinematic"
prompt = [prompt] * 1
prompt_ids = pipeline.prepare_inputs(prompt)

capacities = [3*1e9 for _ in range(len(jax.devices()[2:8]))]
distributor = TensorDistributor(devices=jax.devices()[2:8], capacities=capacities)
p_params = distributor.greedily_distribute_tensors(params, squeeze_first_axis=True) # the first axis aids in jitting

def create_key(seed=0):
    return jax.random.PRNGKey(seed)
rng = create_key(0)
rng = jax.random.split(rng, 1)

images = pipeline(prompt_ids, p_params, rng, jit=False)[0]
images = pipeline.numpy_to_pil(images)

def image_grid(imgs, rows, cols):
    w,h = imgs[0].size
    grid = Image.new('RGB', size=(cols*w, rows*h))
    for i, img in enumerate(imgs): grid.paste(img, box=(i%cols*w, i//cols*h))
    return grid

image_grid(images, 1, 1)

Output:

Some weights of the model checkpoint at stable-diffusion-v1-5/safety_checker were not used when initializing FlaxStableDiffusionSafetyChecker: {('vision_model', 'vision_model', 'embeddings', 'position_ids')}
- This IS expected if you are initializing FlaxStableDiffusionSafetyChecker from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing FlaxStableDiffusionSafetyChecker from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of the model checkpoint at stable-diffusion-v1-5/text_encoder were not used when initializing FlaxCLIPTextModel: {('text_model', 'embeddings', 'position_ids')}
- This IS expected if you are initializing FlaxCLIPTextModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing FlaxCLIPTextModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).

Stable diffusion output

Here’s the definition of TensorDistributor. It wraps the tree_map shown earlier and keeps a tally of the remaining capacity on each TPU device to decide where to allocate:

Code
import random

import jax
from flax.core import frozen_dict

class TensorDistributor:
  def __init__(self, devices, capacities):
    assert len(devices) >= 1, "At least one device is needed."
    assert len(devices) == len(capacities), "Devices and capacities must match."
    self.devices, self.capacities = devices, capacities
    self.idx = len(devices)-1 # index of current allocation device
  
  @staticmethod
  def _move(tensor, device, squeeze_first_axis=False):
    result = jax.device_put_sharded([tensor], devices=[device])
    return result[0] if squeeze_first_axis else result

  @staticmethod
  def randomly_distribute_tensors(
      params: frozen_dict.FrozenDict,
      squeeze_first_axis=False,
      devices=jax.devices()
  ):
      """Spreads all tensors in `params` across jax.devices() randomly
      
      Args:
          params (dict): Params dict for the network.
      Returns:
          new_params: A dictionary identical to params in structure, 
            except tensors are distributed to different devices.
      """

      map = jax.tree_util.tree_map(
              lambda x: TensorDistributor._move(x, random.choice(devices), 
                                                squeeze_first_axis),
              params
            )
      return map

  def _move_greedy(self, tensor, squeeze_first_axis=False):
    tensor_size = tensor.nbytes
    if self.capacities[self.idx] >= tensor_size:
      self.capacities[self.idx] -= tensor_size
      return TensorDistributor._move(tensor, self.devices[self.idx], 
                                     squeeze_first_axis)

    # find a new device, starting from the one before the current device and
    # wrapping around (the current device is re-checked last)
    for i in range(self.idx-1, self.idx-1-len(self.devices), -1):
      if i < 0:
        i += len(self.devices) 
      if self.capacities[i] >= tensor_size:
        self.idx = i
        self.capacities[self.idx] -= tensor_size
        return TensorDistributor._move(tensor, self.devices[self.idx],
                                       squeeze_first_axis)
    else:
      raise Exception((f"Failed to allocate {tensor.nbytes} bytes because the "
                       f"devices have {self.capacities} bytes free."))
  
  def greedily_distribute_tensors(self,
      params: frozen_dict.FrozenDict,
      squeeze_first_axis=False
  ):
      """Spreads all tensors in `params` across jax.devices() in device order, 
      respecting memory limits set forth by `self.capacities`.
      
      Args:
          params (dict): Params dict for the network.
      Returns:
          new_params: A dictionary identical to params in structure, except
            tensors are distributed to different devices.
      """
      
      map = jax.tree_util.tree_map(
              lambda x: self._move_greedy(x, squeeze_first_axis),
              params
            )
      return map
Disclaimer: Feel free to use TensorDistributor according to the MIT License. No warranties are implied.

We could have avoided defining capacities if we had a function that returns the free memory in bytes, i.e. something like torch.cuda.memory_reserved(0) - torch.cuda.memory_allocated(0) but for TPUs in JAX. However, the only way I could find to check free memory in JAX is to use a profiler, which only outputs prof files. This is not a big deal, since hard-coding capacities seems to be sufficient for the purpose and also gives the user deliberate control over memory consumption per TPU.
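Newer JAX releases expose a memory_stats() method on device objects, which may offer a way around the hard-coded capacities. This is untested on Colab’s TPU runtime. The sketch below assumes the returned dict has bytes_limit and bytes_in_use keys and falls back to None when the backend reports nothing. free_bytes and FakeDevice are hypothetical names used only for illustration.

```python
def free_bytes(device):
    """Return free memory on `device` in bytes, or None if it is not reported.

    Assumes `device.memory_stats()` returns a dict containing `bytes_limit`
    and `bytes_in_use`; some backends return None or omit these keys.
    """
    stats_fn = getattr(device, "memory_stats", None)
    stats = stats_fn() if callable(stats_fn) else None
    if not stats:
        return None
    limit, used = stats.get("bytes_limit"), stats.get("bytes_in_use")
    if limit is None or used is None:
        return None
    return limit - used


class FakeDevice:
    """Test double so the sketch runs without an accelerator."""

    def __init__(self, stats):
        self._stats = stats

    def memory_stats(self):
        return self._stats


print(free_bytes(FakeDevice({"bytes_limit": 100, "bytes_in_use": 30})))  # 70
print(free_bytes(FakeDevice(None)))                                      # None
```

On a real runtime the call would be along the lines of [free_bytes(d) for d in jax.devices()], and any None entries would still need a hard-coded capacity.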

Hope you find this interesting and useful for supercharging your DL projects using TPUs!