Jae-Kyung Cho Being unique is better than being perfect

MLX: A Machine Learning Framework for Apple Silicon - 04. LLM inference example

LLM Inference with MLX

Let’s check how much of a speed gain we can get from LLM inference using the MLX library. This includes the work of implementing a Transformer-based LLaMA model from scratch.

A sudden personal opinion

Having to implement the LLaMA model from scratch is where I see a downside. PyTorch's tight integration with the amazing Hugging Face platform really stands out: with just 4-5 lines of code, you can download a model from the Hugging Face Hub and run inference on it. MLX was released only recently, so there is plenty of room for improvement. But rather than aiming to be a framework on the same level as torch, wouldn't it be better to focus on connecting code that is already implemented in torch to MLX? Otherwise, the same model code has to be rebuilt by hand every time a new PyTorch architecture comes out, and just thinking about that sounds inconvenient.

Building the LLaMA model

First, let’s import the relevant modules. One thing to be careful about: check the size of the LLM against your MacBook’s RAM. My M1 MacBook Pro has 16 GB of memory, so any model larger than 7B won’t fit in memory.

import mlx.core as mx
import mlx.nn as nn
import math
import gc
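
As a rough check of the memory caveat above, the size of the weights alone can be estimated from the parameter count and the bytes per parameter. This is only a sketch: the helper name estimate_weight_memory_gb is made up for illustration, and the estimate ignores activations, the KV cache, and whatever the OS itself is using.

```python
def estimate_weight_memory_gb(num_params: float, bytes_per_param: int = 4) -> float:
    """Rough size of the weights alone, in GB (10**9 bytes)."""
    return num_params * bytes_per_param / 1e9


# Parameter counts are nominal (e.g. "7B" is taken as exactly 7e9)
for name, n in [("1.3B", 1.3e9), ("7B", 7e9), ("13B", 13e9)]:
    fp32 = estimate_weight_memory_gb(n, 4)  # float32: 4 bytes per parameter
    fp16 = estimate_weight_memory_gb(n, 2)  # float16: 2 bytes per parameter
    print(f"{name}: ~{fp32:.1f} GB in float32, ~{fp16:.1f} GB in float16")
```

By this estimate a 7B model needs about 28 GB in float32 and about 14 GB in float16, so on a 16 GB machine it is tight even at half precision.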

The LLaMA architecture is built by stacking blocks like the one shown in the figure below. It uses a technique called pre-normalization, which performs RMSNorm before attention.

[Figure: LLaMA block]

It also uses SwiGLU as the activation function. SwiGLU is an activation function that combines the Swish activation function and GLU. The mathematical formulas for Swish activation, GLU, and SwiGLU are as follows.

\[\mathrm{Swish}_\beta(x) = x\,\sigma(\beta x)\]
\[\mathrm{GLU}(x,W,V,b,c)=\sigma(xW+b)\otimes(xV+c)\]
\[\mathrm{SwiGLU}(x,W,V,b,c,\beta)=\mathrm{Swish}_\beta(xW+b)\otimes(xV+c)\]

The formulas look difficult, but when beta=1 it can be implemented simply with linear layers.
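
To make the beta=1 case concrete, here is a minimal pure-Python sketch. The function names are hypothetical, and gate and up stand for the outputs of the two linear layers, xW and xV, with the biases b and c left out.

```python
import math


def sigmoid(z: float) -> float:
    return 1.0 / (1.0 + math.exp(-z))


def swish(z: float, beta: float = 1.0) -> float:
    # Swish_beta(z) = z * sigmoid(beta * z)
    return z * sigmoid(beta * z)


def swiglu(gate, up, beta: float = 1.0):
    # Element-wise product of Swish(gate) and up
    return [swish(g, beta) * u for g, u in zip(gate, up)]


gate = [0.0, 1.0, -2.0]  # stands in for xW
up = [2.0, 3.0, 0.5]     # stands in for xV
print(swiglu(gate, up))
```

With beta=1 the whole activation reduces to gate * sigmoid(gate) * up, so nothing beyond linear layers and a sigmoid is needed.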

So let’s define a total of 5 modules: LlamaAttention, LlamaMLPLayer, LlamaEncoderLayer, Llama, and LlamaForCausalLM. The code follows the Hugging Face LLaMA implementation, and the module variable names are set identically to those of the Hugging Face model. This makes it possible to download the Hugging Face weights and apply them directly.

Llama Attention

class LlamaAttention(nn.Module):
    def __init__(self, 
                 dims: int,
                 num_heads: int):
        super().__init__()
        
        self.num_heads = num_heads
        
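        # Hugging Face checkpoints store the q/k projections in the rotate-half layout,
        # which corresponds to traditional=False in nn.RoPE
        self.rope = nn.RoPE(dims // num_heads, traditional=False)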
        self.q_proj = nn.Linear(dims, dims, bias=False)
        self.k_proj = nn.Linear(dims, dims, bias=False)
        self.v_proj = nn.Linear(dims, dims, bias=False)
        self.o_proj = nn.Linear(dims, dims, bias=False)
        
    def __call__(self, x, mask=None, cache=None):
        queries = self.q_proj(x)
        keys = self.k_proj(x)
        values = self.v_proj(x)
        
        # Shape values
        num_heads = self.num_heads
        B, L, D = queries.shape
        
        # Shape preprocessing (B, num_heads, L, D // num_heads)
        queries = queries.reshape(B, L, num_heads, -1).transpose(0,2,1,3)
        keys = keys.reshape(B, L, num_heads, -1).transpose(0,2,1,3)
        values = values.reshape(B, L, num_heads, -1).transpose(0,2,1,3)
        
        # query[0,0,0,0] -> -0.4308
        
        # RoPE processing
        if cache is not None:
            # axis = 2 -> Length (sequence position)
            key_cache, value_cache = cache
            queries = self.rope(queries, offset=key_cache.shape[2])
            keys = self.rope(keys, offset=key_cache.shape[2])
            keys = mx.concatenate([key_cache, keys], axis=2)
            values = mx.concatenate([value_cache, values], axis=2)
        else:
            queries = self.rope(queries)
            keys = self.rope(keys)
            
        # query[0,0,0,-1] = 0.6716
            
        scale = math.sqrt(1 / queries.shape[-1])
        # queries is (B, num_heads, L, D // num_heads)
        # keys is (B, num_heads, D//num_heads, L)
        # result is (B,num_heads, L, L)
        scores = (queries * scale) @ keys.transpose(0, 1, 3, 2)
        
        if mask is not None:
            scores = scores + mask
        scores = mx.softmax(scores, axis=-1)
        # score @ values is  (B, num_heads, L, D // num_heads)
        # after transpose it's (B, L, num_heads, D // num_heads)
        # value_hat is (B, L, D)
        values_hat = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)
        
        # values_hat[0,0,0] = -0.0005
        
        # Return key, value (for the cache)
        return self.o_proj(values_hat), (keys, values)
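
The cache handling above can be hard to picture, so here is a small NumPy sketch of just the shape bookkeeping. The sizes are made up and random arrays stand in for the real keys. The only point is that the cache grows along axis 2 (the sequence axis) and that its current length is what gets passed as the RoPE offset.

```python
import numpy as np

B, H, Dh = 1, 2, 4  # hypothetical batch size, number of heads, head dimension
rng = np.random.default_rng(0)

# Pretend the prompt had 5 tokens, so the cache starts with length 5
key_cache = rng.normal(size=(B, H, 5, Dh))

for step in range(3):
    # The new token's position equals the number of tokens already cached
    offset = key_cache.shape[2]
    new_keys = rng.normal(size=(B, H, 1, Dh))  # keys for one new token
    key_cache = np.concatenate([key_cache, new_keys], axis=2)
    print(f"step {step}: offset={offset}, cache shape={key_cache.shape}")
```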

LLaMA MLP

This is the module where the SwiGLU activation is implemented.

class LlamaMLPLayer(nn.Module):
    def __init__(self,
                 dims: int,
                 mlp_dims: int):
        super().__init__()
        
        self.gate_proj = nn.Linear(dims, mlp_dims, bias=False)
        self.up_proj = nn.Linear(dims, mlp_dims, bias=False)
        self.down_proj = nn.Linear(mlp_dims, dims, bias=False)
        
    def __call__(self, x):
        a = self.gate_proj(x)
        b = self.up_proj(x)
        x = a * mx.sigmoid(a) * b # SwiGLU (the version using beta=1 in Swish)
        x = self.down_proj(x)
        
        return x

LLaMA encoder

class LlamaEncoderLayer(nn.Module):
    def __init__(self,
                 dims: int,
                 mlp_dims: int,
                 num_heads: int):
        super().__init__()
        
        self.self_attn = LlamaAttention(dims, num_heads)
        self.input_layernorm = nn.RMSNorm(dims)
        self.post_attention_layernorm = nn.RMSNorm(dims)
        self.mlp = LlamaMLPLayer(dims, mlp_dims)
        
    def __call__(self, x, mask=None, cache=None):
        y = self.input_layernorm(x)
        y, cache = self.self_attn(y, mask=mask, cache=cache)
        x = x + y
        
        y = self.post_attention_layernorm(x)
        y = self.mlp(y)
        x = x + y
        
        return x, cache

LLaMA model

class Llama(nn.Module):
    def __init__(
        self, 
        num_layers: int,
        vocab_size: int,
        dims: int,
        mlp_dims: int,
        num_heads: int
    ):
        super().__init__()
        
        self.embed_tokens = nn.Embedding(vocab_size, dims)
        self.layers = [
            LlamaEncoderLayer(dims, mlp_dims, num_heads) for _ in range(num_layers)
        ]
        self.norm = nn.RMSNorm(dims)
        
    def __call__(self, x, cache=None, return_cache=False):
        if cache:
            assert len(self.layers) == len(cache), "Length of cache must be equal to number of layers"
            mask = None
        else:
            mask = nn.MultiHeadAttention.create_additive_causal_mask(x.shape[1])
            mask = mask.astype(self.embed_tokens.weight.dtype)
            cache = [None] * len(self.layers)
        
        x = self.embed_tokens(x)
            
        for i, layer in enumerate(self.layers):
            x, c = layer(x, mask, cache=cache[i])
            if return_cache:
                cache[i] = c
        x = self.norm(x)
        
        if return_cache:
            return x, cache
        
        return x
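
To illustrate what the additive causal mask does to the attention scores, here is a NumPy sketch. It is not the MLX implementation: the helper names are hypothetical, and -1e9 is simply an assumed "large negative number" for the masked positions.

```python
import numpy as np


def additive_causal_mask(n: int) -> np.ndarray:
    # Position i may attend to positions j <= i; future positions get a large negative value
    idx = np.arange(n)
    return np.where(idx[None, :] > idx[:, None], -1e9, 0.0)


def softmax(x: np.ndarray, axis: int = -1) -> np.ndarray:
    x = x - x.max(axis=axis, keepdims=True)  # subtract the max for numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


scores = np.zeros((3, 3))  # uniform scores, so only the mask shapes the result
weights = softmax(scores + additive_causal_mask(3))
print(weights.round(2))
```

Each row spreads its weight only over itself and earlier positions. When a cache is passed in and a single new token is processed, that token is allowed to attend to everything already cached, which is why the model above sets mask = None in that branch.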

LLaMA causal LM

class LlamaForCausalLM(nn.Module):
    def __init__(
        self, 
        num_layers: int,
        vocab_size: int,
        dims: int,
        mlp_dims: int,
        num_heads: int
    ):
        super().__init__()
        self.model = Llama(num_layers,
                            vocab_size,
                            dims,
                            mlp_dims,
                            num_heads)
        self.lm_head = nn.Linear(dims, vocab_size, bias=False)
        
    def __call__(self, x):
        x = self.model(x)
        
        return self.lm_head(x)
        
    def generate(self, x, temp=1.0):
        x, cache = self.model(x, return_cache=True)
        # Use only the last token (the previous ones are the prompt)
        y = self.lm_head(x[:, -1])
        y = mx.random.categorical(y * (1/temp))
        
        # Due to lazy evaluation, even if we yield, the computation isn't necessarily performed.
        yield y
        
        while True:
            x = y[:, None]
            x, cache = self.model(x, cache=cache, return_cache=True)
            y = self.lm_head(x[:, -1])
            y = mx.random.categorical(y * (1/temp))
            
            yield y
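
generate samples the next token from the logits scaled by 1/temp. As an illustration of what that scaling does, here is a NumPy sketch. It is not how mx.random.categorical is implemented, and sample_with_temperature is a hypothetical helper.

```python
import numpy as np


def sample_with_temperature(logits, temp: float = 1.0, rng=None) -> int:
    if rng is None:
        rng = np.random.default_rng()
    scaled = np.asarray(logits, dtype=np.float64) / temp
    scaled = scaled - scaled.max()  # numerical stability before exponentiating
    probs = np.exp(scaled)
    probs = probs / probs.sum()
    return int(rng.choice(len(probs), p=probs))


logits = [2.0, 1.0, 0.1]
rng = np.random.default_rng(0)

# A very low temperature makes the distribution almost one-hot on the largest logit
cold = [sample_with_temperature(logits, temp=0.01, rng=rng) for _ in range(20)]
# temp=1.0 samples from the plain softmax distribution
warm = [sample_with_temperature(logits, temp=1.0, rng=rng) for _ in range(20)]
print(cold)
print(warm)
```

Lower temperatures push the sampling toward greedy decoding, while temp=1.0 leaves the model's distribution unchanged.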

Converting original Llama weight

So let’s download an appropriate LLaMA-architecture model from the Hugging Face Hub, convert it to MLX weights, and then load it into the MLX model we built ourselves. Considering memory constraints, I selected the 1.3B model princeton-nlp/Sheared-LLaMA-1.3B.

map_torch_to_mlx converts layer names when they are set differently. Since the model above has exactly the same variable names and structure as the Hugging Face LLaMA model, no key renaming is needed. In the original code from the MLX LLM inference example, the structure and variable names were different, so that conversion was needed; it is left commented out here.

from itertools import starmap

import numpy as np
import torch

def map_torch_to_mlx(key, value):
    # if "tok_embedding" in key:
    #     key = "embedding.weight"

    # elif "norm" in key:
    #     key = key.replace("attention_norm", "norm1").replace("ffn_norm", "norm2")

    # elif "wq" in key or "wk" in key or "wv" in key or "wo" in key:
    #     key = key.replace("wq", "query_proj")
    #     key = key.replace("wk", "key_proj")
    #     key = key.replace("wv", "value_proj")
    #     key = key.replace("wo", "out_proj")

    # elif "w1" in key or "w2" in key or "w3" in key:
    #     # The FFN is a separate submodule in PyTorch
    #     key = key.replace("feed_forward.w1", "linear1")
    #     key = key.replace("feed_forward.w3", "linear2")
    #     key = key.replace("feed_forward.w2", "linear3")

    # elif "output" in key:
    #     key = key.replace("output", "out_proj")

    # elif "rope" in key:
    #     return None, None
    
    return key, value.numpy()

For now, there is still no way to load a huggingface model into MLX by name alone. Therefore, we use an approach where we convert the huggingface torch weights into numpy arrays, save them to files, and then convert those numpy arrays into mlx arrays to load them. The MLX documentation states that this will be improved in the future.

Let’s create a folder called hf and save the weights.npz file inside it. Since the intermediate numpy arrays would otherwise waste memory, we delete the Hugging Face model and run garbage collection to keep memory usage as low as possible.

from transformers import AutoModelForCausalLM
import os

hf_path = "princeton-nlp/Sheared-LLaMA-1.3B"
output_path = f"hf/{hf_path}"
output_file = output_path + "/{key}.npz"
if not os.path.exists(output_path):
    os.makedirs(output_path)

if not os.path.exists(output_path + "/DONE"):
    hf_model = AutoModelForCausalLM.from_pretrained(hf_path)
    hf_model.config.to_json_file(output_path + "/config.json")
    state = hf_model.state_dict()

    np.savez(
        output_file.format(key="weights"),
        **{k: v for k, v in starmap(map_torch_to_mlx, state.items()) if k is not None}
    )

    # Write an empty marker file to indicate that the conversion is done
    with open(output_path + "/DONE", "w"):
        pass
    del hf_model
    
    gc.collect()
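
The save-and-reload round trip used above can be tried in isolation with a toy state dict. This is a minimal sketch with made-up array shapes, writing to an in-memory buffer instead of a file. It shows that dotted parameter names survive as keys of the npz archive.

```python
import io

import numpy as np

# Toy stand-in for a converted state dict (the shapes are made up)
state = {
    "model.embed_tokens.weight": np.ones((4, 2), dtype=np.float32),
    "lm_head.weight": np.zeros((4, 2), dtype=np.float32),
}

buffer = io.BytesIO()        # in-memory stand-in for weights.npz
np.savez(buffer, **state)    # each key becomes one array in the archive
buffer.seek(0)

loaded = np.load(buffer)
print(sorted(loaded.files))
print(loaded["lm_head.weight"].shape)
```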

Using these files, we update the model weights with mlx.utils.tree_unflatten. Given the {key: value} pairs from the npz file, a key structured like layers.2.attention.query_proj.weight is converted into the following nested form:

{"layers": [..., ..., {"attention": {"query_proj": {"weight": ...}}}]}

Converting into this form lets you update the corresponding weights directly through model.update.
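
To show the idea behind this conversion, here is a simplified pure-Python sketch. It is a hypothetical re-implementation for illustration, not the MLX source. In particular, padding missing list indices with empty dicts is an assumption of this sketch.

```python
def unflatten(items):
    """Turn [("a.0.b", value), ...] into nested dicts and lists (illustration only)."""
    root = {}
    for dotted_key, value in items:
        parts = dotted_key.split(".")
        node = root
        for part in parts[:-1]:
            node = node.setdefault(part, {})
        node[parts[-1]] = value
    return _dicts_to_lists(root)


def _dicts_to_lists(node):
    if not isinstance(node, dict):
        return node
    node = {k: _dicts_to_lists(v) for k, v in node.items()}
    # A dict whose keys are all integers is treated as a list
    if node and all(k.isdigit() for k in node):
        size = max(int(k) for k in node) + 1
        out = [{} for _ in range(size)]  # assumption: pad missing indices with {}
        for k, v in node.items():
            out[int(k)] = v
        return out
    return node


result = unflatten([("layers.2.attention.query_proj.weight", 1.0)])
print(result)
```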

from mlx.utils import tree_unflatten
import json

# https://huggingface.co/beomi/llama-2-ko-7b/blob/main/config.json
with open(output_path + "/config.json", 'r') as f:
    config = json.load(f)

model = LlamaForCausalLM(num_layers=config['num_hidden_layers'],
              vocab_size=config['vocab_size'], # tokenizer.vocab_size
              dims=config['hidden_size'], 
              mlp_dims=config['intermediate_size'], 
              num_heads=config['num_attention_heads'])
for np_file in os.listdir(output_path):
    if np_file.endswith('.npz'):
        data = tree_unflatten(list(mx.load(output_path + f"/{np_file}").items()))
        # if "model" in data.keys():
        #     if "layers" in data["model"].keys():
        #         data['model']['layers'].extend([{}] * (config['num_hidden_layers']-len(data['model']['layers'])))
        model.update(data)
        del data
        
gc.collect()
mx.eval(model.parameters())

LLaMA inference

All preparations are complete!! So let’s run inference. For the tokenizer, we use Huggingface’s tokenizer as-is.

from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained(hf_path)
from time import time

prompt = "My name is Julien and I like to"
max_token_length = 100

x = mx.array([tokenizer.encode(prompt)])
tokens = []

mx.eval(x)

tic = time()
for token in model.generate(x, temp=1.0):
    tokens.append(token)
    
    if len(tokens) == 1:
        mx.eval(token)
    
    if len(tokens) >= max_token_length:
        break

mx.eval(tokens)
toc = time()

s = tokenizer.decode([t.item() for t in tokens], skip_special_tokens=True)
print(s, flush=True)
print(f"Throughput (MLX): {len(tokens)/(toc-tic)} tokens/sec")
write. I write about... lots of things. Unmannerstood, Difficult & Unniverse
 nova the unicorn, which makes me confidently and I am unable to deal with everything from fiction stories along with lol elfing, podcasting from 1sthing jobsearching
com.Home about...Find backs.
i and bjigg@nak..
2tourn. Most dês.
Matt. I'
Throughput (MLX): 21.244782071635516 tokens/sec

We can confirm that inference runs. Throughput is about 21 tokens/sec.

Comparing throughput with Huggingface

So let’s compare with using Huggingface, i.e. PyTorch, to check how much of a speed gain we can get.

# the existing model must be unloaded due to lack of memory
del model
gc.collect()
hf_model = AutoModelForCausalLM.from_pretrained(hf_path)

# NOTE: this encodes the literal string 'prompt', not the `prompt` variable,
# which is why the generated text below starts with "promptly"
input_ids = tokenizer.encode('prompt', return_tensors='pt')
tic = time()
output = hf_model.generate(input_ids, max_length=max_token_length)
toc = time()
s = tokenizer.decode(output[0], skip_special_tokens=True)

print(s, flush=True)
# NOTE: len(tokens) is the token count from the MLX run above
print(f"Throughput (HF, cpu): {len(tokens)/(toc-tic)} tokens/sec")

print("\n" + "="*20 + "\n")

# Same measurement again, this time on the Apple GPU through the MPS backend
device = torch.device("mps:0")
hf_model.to(device)
input_ids = tokenizer.encode('prompt', return_tensors='pt').to(device)
tic = time()
output = hf_model.generate(input_ids, max_length=max_token_length)
toc = time()
s = tokenizer.decode(output[0], skip_special_tokens=True)

print(s, flush=True)
print(f"Throughput (HF, mps): {len(tokens)/(toc-tic)} tokens/sec")

    promptly and efficiently.
    We are a family owned and operated business that has been in business for over 20 years. We are a full service company that can handle all of your needs. We are a full service company that can handle all of your needs.
    We are a full service company that can handle all of your needs. We are a full service company that can handle all of your needs.
    We are a full service company that can handle all of your needs. We
    Throughput (HF, cpu): 4.973824978994109 tokens/sec
    
    ====================
    
    promptly and efficiently.
    We are a family owned and operated business that has been in business for over 20 years. We are a full service company that can handle all of your needs. We are a full service company that can handle all of your needs.
    We are a full service company that can handle all of your needs. We are a full service company that can handle all of your needs.
    We are a full service company that can handle all of your needs. We
    Throughput (HF, mps): 9.700159370346634 tokens/sec

# Free the memory used by the Hugging Face model
del hf_model
gc.collect()

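With Hugging Face on the CPU, throughput is about 5 tokens/sec, and with MPS acceleration it is about 10 tokens/sec. Against the roughly 21 tokens/sec measured with MLX, that is about a 2x speed improvement over MPS and about 4x over the CPU. Treat this as a rough comparison: the Hugging Face throughput is computed from the MLX token count (len(tokens)), while max_length also counts the prompt tokens, and the two runs do not use the same prompt or sampling setup.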
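
The speedup ratios follow directly from the measured throughputs. A small sketch of the arithmetic, using the numbers printed above rounded to two decimals:

```python
# Throughputs in tokens/sec, rounded from the runs above
throughput = {"MLX": 21.24, "HF (cpu)": 4.97, "HF (mps)": 9.70}

speedup = {
    name: throughput["MLX"] / tps
    for name, tps in throughput.items()
    if name != "MLX"
}

for name, ratio in speedup.items():
    print(f"MLX vs {name}: {ratio:.2f}x")
```

This gives roughly 4.3x over the CPU run and roughly 2.2x over the MPS run.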
