Despite their documentation saying otherwise, the Linux TheRock builds of PyTorch do not currently support Flash Attention for gfx1151
(Strix Halo). I’ve filed an issue so hopefully that will be fixed upstream: [Issue]: PyTorch Flash Attention with gfx1151 · Issue #1364 · ROCm/TheRock · GitHub
But in the meantime, and potentially useful if you’re looking to build your own PyTorch for other reasons, I’ve created a build script that includes multiple patches for building PyTorch w/ aotriton (for basic Flash Attention) and w/ GLOO distributed support (needed for launching vLLM). It leverages TheRock CI scripts, which I think makes it better than previous standalone efforts from a maintainability and staying-in-sync-w/-ROCm-updates perspective: strix-halo-testing/torch-therock at main · lhl/strix-halo-testing · GitHub
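If you want to sanity-check whether a given torch build actually has a working Flash Attention SDPA backend (and GLOO compiled in), here’s a minimal probe using standard torch APIs - this isn’t part of my scripts, just the kind of check I run after a build:

```python
# Quick sanity check of a PyTorch ROCm build: is the Flash Attention SDPA
# backend usable on this GPU, and is GLOO compiled in for torch.distributed?
import torch
import torch.distributed as dist
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

print(f"torch {torch.__version__} on {torch.cuda.get_device_name(0)}")
print(f"GLOO available: {dist.is_gloo_available()}")

q, k, v = (torch.randn(1, 8, 1024, 64, device="cuda", dtype=torch.float16)
           for _ in range(3))

# Restrict SDPA to the Flash Attention backend only; builds without a working
# FA kernel for this arch raise instead of silently falling back to math.
try:
    with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
        F.scaled_dot_product_attention(q, k, v, is_causal=True)
    print("Flash Attention SDPA backend: OK")
except RuntimeError as e:
    print(f"Flash Attention SDPA backend unavailable: {e}")
```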
The performance difference on fwd and bwd pass attention is… not insignificant. In my basic test, it is 20-30X faster on fwd pass and 8-20X faster on bwd pass. Before:
╔═════════════════════════════════════════════════════════════════════════════════════════╗
║ Testing XLarge: B=16, H=16, S=4096, D=64 ║
╚═════════════════════════════════════════════════════════════════════════════════════════╝
Estimated memory per QKV tensor: 0.12 GB
Total QKV memory: 0.38 GB
+--------------+----------------+-------------------+----------------+-------------------+
| Operation | FW Time (ms) | FW FLOPS (TF/s) | BW Time (ms) | BW FLOPS (TF/s) |
+==============+================+===================+================+===================+
| Causal FA2 | 1950.32 | 0.28 | 2339.34 | 0.59 |
+--------------+----------------+-------------------+----------------+-------------------+
| Regular SDPA | 1498.6 | 0.37 | 2351.16 | 0.58 |
+--------------+----------------+-------------------+----------------+-------------------+
After:
╔═════════════════════════════════════════════════════════════════════════════════════════╗
║ Testing XLarge: B=16, H=16, S=4096, D=64 ║
╚═════════════════════════════════════════════════════════════════════════════════════════╝
Estimated memory per QKV tensor: 0.12 GB
Total QKV memory: 0.38 GB
+--------------+----------------+-------------------+----------------+-------------------+
| Operation | FW Time (ms) | FW FLOPS (TF/s) | BW Time (ms) | BW FLOPS (TF/s) |
+==============+================+===================+================+===================+
| Causal FA2 | 68.5042 | 8.03 | 115.386 | 11.91 |
+--------------+----------------+-------------------+----------------+-------------------+
| Regular SDPA | 75.1694 | 7.31 | 279.074 | 4.92 |
+--------------+----------------+-------------------+----------------+-------------------+
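For reference, here’s a minimal sketch of how fwd/bwd timings like these can be taken (not my exact benchmark script). The TF/s columns above appear consistent with counting causal attention FLOPs as 2·B·H·S²·D for the forward pass and 2.5× that for the backward pass:

```python
# Minimal sketch of a fwd/bwd SDPA timing loop (not the exact benchmark above).
import torch
import torch.nn.functional as F

B, H, S, D = 16, 16, 4096, 64  # the "XLarge" shape from the tables
q, k, v = (torch.randn(B, H, S, D, device="cuda", dtype=torch.float16,
                       requires_grad=True) for _ in range(3))

def time_ms(fn, iters=10, warmup=3):
    for _ in range(warmup):
        fn()
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        fn()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters

fwd_ms = time_ms(lambda: F.scaled_dot_product_attention(q, k, v, is_causal=True))

out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
grad = torch.ones_like(out)
bwd_ms = time_ms(lambda: out.backward(grad, retain_graph=True))

# Causal attention FLOP estimate: half of the full 4*B*H*S^2*D (two matmuls),
# with backward counted as 2.5x forward.
fwd_flops = 2 * B * H * S * S * D
print(f"fwd: {fwd_ms:8.2f} ms  {fwd_flops / fwd_ms / 1e9:.2f} TF/s")
print(f"bwd: {bwd_ms:8.2f} ms  {2.5 * fwd_flops / bwd_ms / 1e9:.2f} TF/s")
```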
You can actually compare against results from a few months ago; the numbers are getting a bit better: [Issue]: Is there a ROCm version that supports gfx1151? · Issue #4499 · ROCm/ROCm · GitHub (Of course, since this still needs to be built by internet randos and there’s been little/no progress upstream for over 3 months, it’s a glass half-full/half-empty thing.)
The other thing I finally got around to is that, with my own build of torch and some elbow grease, I was able to get vLLM running. I have some (not perfect) build scripts for vLLM: strix-halo-testing/vllm at main · lhl/strix-halo-testing · GitHub - these are WIP, but I’ll share them here in case people are interested. I may or may not polish them later, but the build script roughly documents the gymnastics you currently need to do. A couple of notes:
- Current TheRock PyTorch builds don’t work w/ vLLM for me - they cause GPU hangs while building the CUDA graph, or, if you skip that w/ eager mode, on warmup (see the sketch after this list). If you want it to work, you probably need to build your own torch.
- amdsmi appears to be incompatible w/ either TheRock ROCm or gfx1151 in general - it causes segfaults, and it was easiest for me to just remove it and patch it out of vLLM. Maybe someone at AMD should fix that? (crazy thought, I know)
- My tests with the standard Llama 2 7B model worked, but gpt-oss complained about missing triton-kernels. Of course I have a custom triton and aotriton built, so…
- Performance is as expected, not great, but it’s especially bad for W8A8-INT8, which was actually a surprise to me (RDNA3.5 has INT8 support, and on Nvidia W8A8-INT8 is by far the fastest quant, although for older cards that’s mostly due to the Marlin kernels). F16 and GPTQ are OK, but neither beats llama.cpp, even at c=16 where you’d expect vLLM to have the biggest advantage.
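To make the eager-mode point concrete, here’s a minimal vLLM smoke-test sketch (assuming a local torch + vLLM build that actually works on gfx1151; the model is just the one I happened to test with):

```python
# Minimal vLLM smoke test. enforce_eager=True skips CUDA/HIP graph capture,
# which is the stage where the stock TheRock torch builds hang for me.
from vllm import LLM, SamplingParams

llm = LLM(
    model="meta-llama/Llama-2-7b-chat-hf",
    dtype="float16",
    enforce_eager=True,            # no graph capture
    gpu_memory_utilization=0.85,   # leave headroom on unified memory
)

outputs = llm.generate(
    ["Briefly explain what Strix Halo is."],
    SamplingParams(temperature=0.7, max_tokens=128),
)
print(outputs[0].outputs[0].text)
```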
You can visit my repo for more numbers, but for example, at concurrency=16, here’s what FP16 results looked like (using vllm bench and ShareGPT):
Backend | Model/Quant | Duration (s) | Req/s | Output Tok/s | Total Tok/s | Mean TTFT (ms) | Mean TPOT (ms) |
---|---|---|---|---|---|---|---|
vLLM | meta-llama/Llama-2-7b-chat-hf (FP16) | 155.97 | 0.41 | 72.17 | 186.35 | 871.62 | 162.64 |
llama.cpp | Llama-2-7b-chat-hf.f16.gguf | 79.95 | 0.45 | 66.85 | 83.60 | 497.90 | 175.86 |
And here’s what Q4 looks like at c=16:
Backend | Model/Quant | Duration (s) | Req/s | Output Tok/s | Total Tok/s | Mean TTFT (ms) | Mean TPOT (ms) |
---|---|---|---|---|---|---|---|
vLLM | TheBloke/Llama-2-7B-Chat-GPTQ | 85.72 | 0.75 | 148.91 | 356.67 | 787.02 | 90.31 |
llama.cpp | llama-2-7b.Q4_K_M.gguf | 32.85 | 1.10 | 178.68 | 219.44 | 381.53 | 64.94 |
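For context, the numbers above come from vllm bench serve pointed at a running server with the ShareGPT dataset. The invocation is roughly shaped like the sketch below - the flag names are from my setup and may differ across vLLM versions, so treat the details as assumptions rather than a recipe:

```python
# Rough sketch of the benchmark invocation (flags may vary by vLLM version).
# Assumes the vLLM OpenAI-compatible server is already serving the model locally.
import subprocess

subprocess.run([
    "vllm", "bench", "serve",
    "--model", "meta-llama/Llama-2-7b-chat-hf",
    "--dataset-name", "sharegpt",
    "--dataset-path", "ShareGPT_V3_unfiltered_cleaned_split.json",
    "--num-prompts", "64",
    "--max-concurrency", "16",
], check=True)
```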
vLLM would be useful for compatibility with new models, training/common ML frameworks, etc., but my experience trying to run gpt-oss makes me think that (as usual) it won’t be so easy.
Anyway, I’ll start a thread here and maybe revisit at some point in the future, but I have lots of real work that I should be doing, and I’ve also once again reached the point of asking why a $250B corporation selling an “AI” workstation chip (one that’s been on the market for 6+ months now) doesn’t care enough to get even the basics working. (I guess that’s maybe why the other company, which will charge an arm and a leg but give you something that works, is a $4T market cap company. But at this rate, the Chinese will deliver a real alternative for AI hardware long before AMD does.)