# Research: Whisper/CTranslate2 Cold Start on Blackwell (SM_121)

**Date:** 2026-02-28
**Context:** Audio pipeline on DGX Spark (GB10, SM_121) takes ~150s on first `transcribe()` after every container/process restart. Model load is ~10s. Subsequent inferences are 1-2s.

## Root Cause Analysis

### The 150s is PTX JIT Compilation — Confirmed

The delay is **not** cuBLAS algorithm selection or cuDNN autotuning. It is **CUDA PTX-to-SASS JIT compilation**.

**What happens:**

1. CTranslate2 (and PyTorch, pyannote, wav2vec2) ship with **PTX intermediate code** — not native SASS binaries for SM_121
2. The DGX Spark's GB10 GPU reports as **SM_121** (compute capability 12.1)
3. The `mekopa/whisperx-blackwell` container uses **architecture spoofing** — it patches `get_device_capability()` to return (9, 0) (Hopper) so SM_90 PTX code is accepted
4. The CUDA driver must **JIT-compile every unique PTX kernel to native SASS** on first use
5. The pipeline hits hundreds of unique kernels across Whisper (transformer attention, convolutions), pyannote (segmentation + embedding), and wav2vec2 (alignment)
6. Each kernel compilation takes milliseconds to seconds — but hundreds of them accumulate to ~150s
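The mismatch driving steps 1-4 can be sketched as a toy check. This is an illustrative model only — the real selection happens inside the CUDA driver — and `needs_ptx_jit` is a hypothetical helper; the arch list mirrors what `torch.cuda.get_arch_list()` reports for a typical wheel:

```python
# Illustrative model (not real CUDA internals): the driver runs a kernel's
# native SASS only on an exact architecture match; otherwise it must
# JIT-compile the newest embedded PTX ("compute_XX") for the device.
def needs_ptx_jit(device_sm: int, embedded_archs: list[str]) -> bool:
    native = {int(a.split("_")[1]) for a in embedded_archs if a.startswith("sm_")}
    return device_sm not in native

# CTranslate2's "Common" wheel (SM_70/75/80/86 SASS + PTX) on the GB10 (SM_121):
print(needs_ptx_jit(121, ["sm_70", "sm_75", "sm_80", "sm_86", "compute_86"]))
# → True: every kernel goes through PTX→SASS JIT on first use
```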

**Why CUDA_CACHE_PATH initially only helped ~6%:**

The NGC base image sets **`CUDA_FORCE_PTX_JIT=1`** — this tells the CUDA runtime to IGNORE all cached SASS and re-compile from PTX on every kernel launch. Even though our `run.sh` set `CUDA_CACHE_DISABLE=0` (re-enabling the cache) and the cache populated to 629 MB / 72 files, the runtime bypassed it entirely.

**The fix: `CUDA_FORCE_PTX_JIT=0`** in the docker run command. This lets the CUDA runtime check the JIT cache before compiling. Result: **276s → 8.1s** (first restart with warm cache) → **3.1s** (subsequent restarts).

The base image sets `=1` presumably because SM_121 has no matching SASS in any library, so forcing PTX ensures predictable behavior. But it defeats the entire purpose of the JIT cache.

### SM_121 Is a Unique Architecture

Key insight from [Backend.AI's deep analysis](https://www.backend.ai/blog/2026-02-is-dgx-spark-actually-a-blackwell): SM_121 is **not** datacenter Blackwell (SM_100). It uses **Ampere-era `mma.sync` instructions** rather than SM_100's `tcgen05`. The instruction set is closer to SM_80 (Ampere) than to SM_100 (datacenter Blackwell).

| Feature | SM_100 (datacenter) | SM_121 (GB10) |
|---------|--------------------|--------------|
| Tensor instruction | tcgen05 (5th gen) | mma.sync (Ampere-era) |
| Tensor memory (TMEM) | 256 KB/SM | None |
| Shared memory | 228 KB | 128 KB |

This means SM_90 PTX is the correct fallback: the driver can JIT-compile SM_90 PTX to native SM_121 code, but it will not select SM_90 SASS for an SM_121 device, and there are **no prebuilt SASS binaries for SM_121** in any current ML library.

### CTranslate2 Specific Issues on Blackwell

Per [OpenNMT/CTranslate2#1865](https://github.com/OpenNMT/CTranslate2/issues/1865):
- **INT8 compute types fail on SM_120/121** — cuBLAS INT8 kernels fail when matrix dimensions aren't divisible by 4
- Fix merged in PR #1937: auto-disables INT8 for SM_120+, forces float16
- Your pipeline already uses `compute_type="float16"`, so this doesn't affect you directly
- CTranslate2 wheels are built with `CUDA_ARCH_LIST="Common"` — which means **SM_70, SM_75, SM_80, SM_86 + PTX** (no SM_90, no SM_120/121 native binaries)
- Every kernel must be JIT-compiled from PTX on SM_121
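The guard that PR #1937 reportedly adds can be sketched as follows (a hypothetical helper, not CTranslate2's actual internals — also a reasonable defensive pattern at the application level):

```python
# Sketch of the SM_120+ compute-type fallback described in PR #1937.
# `safe_compute_type` is a hypothetical name for illustration.
def safe_compute_type(device_sm: int, requested: str = "int8_float16") -> str:
    """Fall back to float16 on SM_120+, where cuBLAS INT8 kernels fail
    for matrix dimensions not divisible by 4."""
    if device_sm >= 120 and "int8" in requested:
        return "float16"
    return requested

print(safe_compute_type(121))  # → float16
```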

---

## Solutions: Ordered by Effort-to-Impact Ratio

### 1. ✅ DONE: Fix CUDA JIT Cache — `CUDA_FORCE_PTX_JIT=0`
**Effort:** 5 minutes | **Impact:** 276s → 3-8s on restarts (97.7% reduction) | **Risk:** None

**Root cause:** The NGC base image sets `CUDA_FORCE_PTX_JIT=1`, which forces the CUDA runtime to re-compile ALL kernels from PTX on every launch, even when cached SASS binaries exist in the JIT cache.

**Fix:** Add `-e "CUDA_FORCE_PTX_JIT=0"` to the `docker run` command in `run.sh`.

**How to verify:**
```bash
# Check CUDA cache is populated (should be 600-700 MB):
docker exec her-os-audio du -sh /cuda-cache/

# Check warmup time in logs:
docker logs her-os-audio 2>&1 | grep -E "Warmup|warmup"
# Expected: "Warmup: total 3.1s" (with warm cache)
#           "Warmup: total 8.1s" (first restart after cache populated)

# Force cold cache (only needed after image rebuild):
rm -rf ~/.local/share/her-os-audio/cuda-cache/*
# Then restart — first run takes ~276s to re-populate cache
```

**Environment variables in run.sh:**
```bash
-e "CUDA_CACHE_DISABLE=0"         # NGC disables cache by default
-e "CUDA_CACHE_PATH=/cuda-cache"  # persistent volume mount
-e "CUDA_CACHE_MAXSIZE=2147483648" # 2 GB limit
-e "CUDA_FORCE_PTX_JIT=0"         # USE the cache (base image sets =1!)
```

### 2. SHORT-TERM: Model-in-Sidecar Architecture (Hot Code Reload Without GPU Restart)
**Effort:** 2-4 hours | **Impact:** Zero cold start during code changes | **Risk:** Low

Currently, editing `pipeline.py` or `main.py` requires a container restart, which triggers the full 150s warmup. Since you bind-mount these files as `:ro`, you're paying this cost on every code change.

**Solution: Decouple model process from API process.**

**Option A: uvicorn-hmr (Hot Module Reload)**
```bash
pip install uvicorn-hmr
# Replace CMD in Dockerfile:
CMD ["uvicorn-hmr", "main:app", "--host", "0.0.0.0", "--port", "9100"]
```
With uvicorn-hmr, the main process never restarts — only changed Python modules are reloaded. The GPU model stays loaded. This eliminates cold start for code-only changes.

**Caveat:** Model objects in module-level globals (`pipeline = AudioPipeline()`) would need to survive reloads. May need to move model loading to a separate module that isn't reloaded.
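One way to satisfy that caveat — a minimal sketch, assuming the hot-reload tool can exclude a module from its watch list; `AudioPipeline` is the hypothetical loader from `pipeline.py`:

```python
# models.py — excluded from hot reload, so the singleton (and the GPU
# weights it holds) survives reloads of main.py.
_pipeline = None

def get_pipeline(loader=None):
    """Build the pipeline once; every later call returns the same object."""
    global _pipeline
    if _pipeline is None:
        if loader is None:
            from pipeline import AudioPipeline  # hypothetical import
            loader = AudioPipeline
        _pipeline = loader()
    return _pipeline
```

`main.py` then calls `get_pipeline()` at request time instead of constructing the model at import, so a module reload never re-triggers the warmup.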

**Option B: Separate model server process**
```python
# model_server.py — long-lived, never restarts
# Loads models once, exposes internal API on unix socket or localhost:9102
# API server (main.py) forwards to model_server.py via HTTP/gRPC
```
This is more work but gives clean separation: restart `main.py` all you want, models stay warm in `model_server.py`.
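A minimal sketch of Option B's split, using only the standard library. The real service would use FastAPI or gRPC, and a real `AudioPipeline` would replace the stub dict — everything here is illustrative:

```python
# Model process owns the (stub) warm model; the API process forwards to it.
import json
import threading
import urllib.request
from http.server import BaseHTTPRequestHandler, HTTPServer

class ModelServer(BaseHTTPRequestHandler):
    model = {"name": "whisper-large-v3"}  # stands in for the loaded GPU model

    def do_POST(self):
        body = self.rfile.read(int(self.headers["Content-Length"]))
        audio_path = json.loads(body)["audio"]
        reply = json.dumps({"text": f"transcript of {audio_path}",
                            "model": self.model["name"]}).encode()
        self.send_response(200)
        self.send_header("Content-Type", "application/json")
        self.end_headers()
        self.wfile.write(reply)

    def log_message(self, *args):  # keep the sketch quiet
        pass

def start_model_server(port=0):
    """Long-lived model process; port=0 binds an ephemeral port."""
    server = HTTPServer(("127.0.0.1", port), ModelServer)
    threading.Thread(target=server.serve_forever, daemon=True).start()
    return server

def transcribe_via_sidecar(port, audio_path):
    """What main.py would do: forward the request to the warm sidecar."""
    req = urllib.request.Request(
        f"http://127.0.0.1:{port}/transcribe",
        data=json.dumps({"audio": audio_path}).encode(),
        headers={"Content-Type": "application/json"})
    with urllib.request.urlopen(req) as resp:
        return json.loads(resp.read())
```

Restarting the API process costs nothing: the sidecar keeps the models (and all JIT-compiled kernels) resident.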

### 3. MEDIUM-TERM: Build CTranslate2 From Source with SM_121 Native Binaries
**Effort:** 4-8 hours | **Impact:** Eliminate PTX JIT entirely (150s → ~0s) | **Risk:** Medium (build may fail)

The root cause is that no prebuilt wheel includes SM_121 (or even SM_90) SASS binaries. Building from source with the right architecture flags would embed native binaries.

```bash
# Inside DGX Spark (has CUDA 13.0 + SM_121 GPU):
git clone --recursive https://github.com/OpenNMT/CTranslate2.git
cd CTranslate2

# Build with SM_121 native + SM_90 PTX fallback:
cmake -B build \
  -DWITH_CUDA=ON \
  -DCUDA_ARCH_LIST="9.0;12.1" \
  -DCUDA_NVCC_FLAGS="-gencode arch=compute_90,code=sm_90 -gencode arch=compute_121,code=sm_121 -gencode arch=compute_90,code=compute_90" \
  -DWITH_CUDNN=ON \
  -DWITH_MKL=OFF \
  -DCMAKE_BUILD_TYPE=Release

cmake --build build -j$(nproc)
cd python && pip install .
```

**Why this works:** With native `sm_121` SASS binaries embedded in the library, CUDA skips JIT compilation entirely — the binary code runs directly on the GPU.

**Risks:**
- CTranslate2's Flash Attention kernels are only written for SM_80 — may need `sm_80` in the list too
- Build requires CUDA 13.0+ toolkit (DGX Spark has this)
- SM_121 support in nvcc is CUDA 13.0+ only
- May expose untested code paths

### 4. MEDIUM-TERM: Switch to Parakeet-TDT via NeMo (NVIDIA's Native ASR)
**Effort:** 1-2 days | **Impact:** Likely fast startup + better accuracy + native Blackwell support | **Risk:** Medium (different pipeline)

NVIDIA's [Parakeet-TDT 1.1B](https://huggingface.co/nvidia/parakeet-tdt-1.1b) is purpose-built for fast inference and has confirmed DGX Spark support via Riva NIM.

**Advantages over Whisper:**
- **RTFx ~2000+** vs Whisper's ~60x — dramatically faster inference
- **600M params** (0.6B) vs 1.55B (large-v3) — less VRAM
- **NeMo/Riva optimized for Blackwell** — native TensorRT engines, no PTX JIT
- **Better WER** — Parakeet v2 achieves 6.05% WER (competitive with Whisper large-v3)
- **Streaming support** — Parakeet TDT supports streaming, Whisper doesn't
- **Punctuation + capitalization** built in (Whisper needs post-processing)

**Limitations:**
- **English-focused** (the Parakeet 0.6B TDT v3 variant is multilingual, but covers a limited set of languages)
- No word-level alignment out of the box (WhisperX gives this via wav2vec2)
- Different API — would need pipeline.py refactoring
- Sarvam auto-routing for Kannada would still be needed

**DGX Spark status:**
- Parakeet 1.1B CTC and RNNT: **confirmed working** on DGX Spark via NIM
- Parakeet 0.6B TDT: **not supported** on Blackwell via NIM
- Direct NeMo usage: **working** per [community benchmarks](https://forums.developer.nvidia.com/t/running-parakeet-speech-to-text-on-spark/356353) — 282x RTFx on GB10

**Quick test:**
```python
import time

import nemo.collections.asr as nemo_asr

model = nemo_asr.models.ASRModel.from_pretrained("nvidia/parakeet-tdt-0.6b-v3")

# Time the first transcribe() (cold: includes any JIT/compile cost)
# against the second (warm) to isolate the startup overhead:
for label in ("cold", "warm"):
    t0 = time.perf_counter()
    result = model.transcribe(["test.wav"])
    print(f"{label}: {time.perf_counter() - t0:.1f}s")
```

### 5. MEDIUM-TERM: whisper.cpp with CUDA on DGX Spark
**Effort:** 1 day | **Impact:** Potentially fast cold start (compiled C++ binary, no Python overhead) | **Risk:** Medium

[whisper.cpp](https://github.com/ggml-org/whisper.cpp) is a pure C/C++ Whisper implementation using GGML. It compiles CUDA kernels at **build time**, not runtime — eliminating JIT overhead entirely.

**Build for DGX Spark:**
```bash
# On DGX Spark:
git clone https://github.com/ggml-org/whisper.cpp.git
cd whisper.cpp

# Build with SM_121 CUDA support:
cmake -B build \
  -DGGML_CUDA=1 \
  -DCMAKE_CUDA_ARCHITECTURES="121" \
  -DCMAKE_BUILD_TYPE=Release

cmake --build build -j$(nproc)

# Download large-v3 GGML model:
bash models/download-ggml-model.sh large-v3

# Test:
./build/bin/whisper-cli -m models/ggml-large-v3.bin -f test.wav
```

**Advantages:**
- **Zero JIT compilation** — all CUDA kernels compiled at build time
- **Minimal startup** — load model weights, ready to go
- **Lower memory footprint** — supports quantized models (Q5, Q8)
- **C++ server mode** available with HTTP API

**Limitations:**
- No built-in diarization (would still need pyannote separately)
- No WhisperX-style forced alignment
- Would need to wrap in a Python FastAPI service or use its built-in server
- Less mature than faster-whisper for production use

### 6. MEDIUM-TERM: PyTorch Whisper + torch.compile with Persistent Cache
**Effort:** 1-2 days | **Impact:** ~4.5x faster inference + cached compilation | **Risk:** Medium-High

Use HuggingFace Transformers Whisper with `torch.compile()` and persistent caching to eliminate cold start.

**How it works:**
1. First run: `torch.compile()` compiles the model graph (takes ~2-5 min)
2. Cache artifacts saved to disk via `torch.compiler.save_cache_artifacts()`
3. Subsequent runs: load cache, skip compilation

```python
import torch
from transformers import WhisperForConditionalGeneration, WhisperProcessor

# Load cache artifacts from a previous run (skipped on first boot)
try:
    with open("/data/torch_compile_cache.bin", "rb") as f:
        torch.compiler.load_cache_artifacts(f.read())
except FileNotFoundError:
    pass

model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-large-v3")
model = model.to("cuda").half()
# Static KV-cache gives torch.compile fixed shapes and avoids recompilation
model.generation_config.cache_implementation = "static"
model = torch.compile(model, mode="reduce-overhead")

# First inference triggers compilation (cached = fast, uncached = slow)
# ...

# Save cache artifacts for the next startup
artifacts, cache_info = torch.compiler.save_cache_artifacts()
with open("/data/torch_compile_cache.bin", "wb") as f:
    f.write(artifacts)
```

**Key detail from [Dropbox blog](https://dropbox.github.io/whisper-static-cache-blog/):** Adding a static KV-cache is critical for torch.compile compatibility with Whisper. Without it, dynamic cache sizes cause recompilation.

**Advantages:**
- 4.5x faster inference per [Dropbox benchmark](https://dropbox.github.io/whisper-static-cache-blog/)
- Portable cache artifacts across container restarts
- Same model (Whisper large-v3), same accuracy

**Limitations:**
- Requires PyTorch 2.7+ (nightly builds for SM_121)
- Loses CTranslate2's memory optimizations
- More VRAM than CTranslate2 (no quantization)
- torch.compile cache portability requires same PyTorch + Triton version + same GPU

### 7. LONG-TERM: TensorRT Engine for Whisper
**Effort:** 2-3 days | **Impact:** Fastest possible inference + instant startup after first build | **Risk:** High

[whisper_trt](https://github.com/NVIDIA-AI-IOT/whisper_trt) converts Whisper to a TensorRT engine optimized for the target GPU.

**How it works:**
1. First run: builds TensorRT engine from ONNX (~5-10 min)
2. Engine cached at `~/.cache/whisper_trt/`
3. Subsequent runs: load cached engine (seconds)

**DGX Spark status:** TensorRT has [beta support for DGX Spark](https://forums.developer.nvidia.com/t/tensorrt-for-blackwell-dgx-spark/355279) but with known issues:
- "Skipping tactic" warnings (non-fatal, falls back to working alternatives)
- Missing Blackwell-specific optimizations
- TensorRT-LLM requires building from source for DGX Spark

**Not recommended yet** due to SM_121 software immaturity.

---

## Comparison Matrix

| Solution | Cold Start | Inference Speed | Effort | Risk | Status |
|----------|-----------|----------------|--------|------|--------|
| **CUDA_FORCE_PTX_JIT=0** | **3-8s (cached)** | Same | 5 min | None | **✓ DONE** |
| Build CT2 from source (SASS) | N/A | Same | 8 hrs | High | **✗ Failed** (see Investigation) |
| Hot code reload | 0s (code changes) | Same | 2-4 hrs | Low | Not needed (3s restart) |
| Parakeet-TDT (NeMo) | TBD (<10s) | 50x faster | 1-2 days | Medium | Future |
| whisper.cpp + CUDA | ~2-5s | Similar | 1 day | Medium | Future |
| torch.compile + cache | ~5-10s (cached) | 4.5x faster | 1-2 days | Medium-High | Future |
| TensorRT engine | ~5s (cached) | Fastest | 2-3 days | High | Future |

---

## Warmup Breakdown (measured 2026-02-28)

Total warmup **before fix**: ~276s on every container restart.

| Component | Time | % of Total | Root Cause |
|-----------|------|-----------|------------|
| STT (CTranslate2/Whisper) | ~141s | 51% | PTX→SASS JIT for hundreds of Whisper kernels |
| Speaker ID (pyannote) | ~76s | 28% | PyTorch CUDA JIT for segmentation + embedding models |
| Other (wav2vec2, setup) | ~59s | 21% | Alignment model + Python imports |

### What Actually Fixed It

**Setting `CUDA_FORCE_PTX_JIT=0` in `run.sh`** lets the CUDA runtime use the JIT cache:

| Scenario | Warmup | Notes |
|----------|--------|-------|
| Before (CUDA_FORCE_PTX_JIT=1) | ~276s | Cache populated but never used |
| After, first restart (warm cache) | 8.1s | Cache loaded from disk |
| After, subsequent restarts | 3.1s | Cache warm in OS page cache |
| After cold cache (image rebuild) | ~276s once | Must JIT once to populate cache |

**~97.7% reduction in warmup time.** Total startup, including model load, is now ~6.5s (down from ~276s).

## Investigation: CTranslate2 SASS Build (Failed)

**Status:** Attempted and abandoned (2026-02-28). Documented for future reference.

### What we tried

Built CTranslate2 from source to embed native SASS binaries, eliminating PTX JIT entirely.

#### Attempt 1: CUDA 13.0 nvcc with SM_121 SASS
- Script: `services/audio-pipeline/build-ct2-sm121.sh`
- Mounted host's CUDA 13.0 toolkit (only nvcc that supports `compute_121`)
- Built with `-DCUDA_ARCH_LIST="9.0" -DCUDA_NVCC_FLAGS="-gencode arch=compute_121,code=sm_121"`
- `cuobjdump` confirmed SM_121 SASS was in the library
- **Result:** `cudaErrorNoKernelImageForDevice` — CUDA 12.8 runtime can't parse fat binaries generated by CUDA 13.0 nvcc

#### Attempt 2: CUDA 12.8 nvcc with SM_90 SASS
- Used container's built-in CUDA 12.8 (no host mount needed)
- Built with `-DCUDA_ARCH_LIST="9.0"` → generates `sm_90` SASS
- **Result:** `cudaErrorNoKernelImageForDevice` — CUDA runtime won't select SM_90 SASS for SM_121 device (exact architecture match required)

#### Attempt 3: LD_PRELOAD to spoof device capability
- Created `cuda_arch_spoof.so` to intercept `cuDeviceGetAttribute` and return SM_90
- Tested with CUDA_FORCE_PTX_JIT=0 to let runtime use SASS
- **Result:** Failed — CUDA runtime's kernel selection is internal, doesn't call the public `cuDeviceGetAttribute` API through the PLT. LD_PRELOAD has no interception point.

### Why SASS can't work on SM_121 with CUDA 12.8

1. **SM_121 SASS compiled with CUDA 13.0:** Fat binary registration code uses CUDA 13.0 format. The container's CUDA 12.8 runtime can't parse it.
2. **SM_90 SASS compiled with CUDA 12.8:** CUDA runtime selects kernels by EXACT architecture match. Device is SM_121, not SM_90 → no match → fallback to PTX → no PTX → error.
3. **LD_PRELOAD spoofing:** Kernel selection happens inside the CUDA runtime via internal mechanisms, not through the public driver API. Can't intercept.
4. **CUDA_FORCE_PTX_JIT=1:** Base image sets this, which forces runtime to IGNORE all SASS and use only PTX. Even if SASS matching worked, this flag defeats it.

### What would actually work for SASS

- **Upgrade to CUDA 13.0+ runtime** in the container — then CUDA would understand SM_121 SASS natively
- **Wait for pip wheels with SM_121 SASS** — when CUDA 13.0 becomes mainstream
- **Use a different base image** with CUDA 13.0+

### Build artifacts (kept for reference)

The build script remains at `services/audio-pipeline/build-ct2-sm121.sh` for documentation purposes. Build artifacts are gitignored.

### CTranslate2 CMake gotcha

CTranslate2 uses **legacy FindCUDA** (cmake_minimum_required 3.7), NOT modern cmake CUDA language. These variables are **silently ignored**:
- `CMAKE_CUDA_COMPILER`
- `CMAKE_CUDA_ARCHITECTURES`

Must use instead:
- `CUDA_TOOLKIT_ROOT_DIR` → tells FindCUDA where to find nvcc
- `CUDA_ARCH_LIST` → CTranslate2's own variable for arch targeting
- `CUDA_NVCC_FLAGS` → extra gencode flags
- `OPENMP_RUNTIME=COMP` → use GCC's libgomp on ARM64 (Intel libiomp5 unavailable)

## Recommended Action Plan

### Phase 1: Quick Wins (Done ✓)

1. ~~**Fix CUDA JIT cache**~~ — **Done.** Set `CUDA_FORCE_PTX_JIT=0` in `run.sh`. The base image's `=1` was forcing re-JIT on every kernel launch, defeating the 629 MB cache. Fix: 276s → 3.1s on restart (97.7% reduction).

2. ~~**Investigate CTranslate2 SASS build**~~ — **Done, abandoned.** SASS-based approaches fail because CUDA 12.8 runtime can't select SM_90 SASS for SM_121 device, CUDA 13.0-compiled SM_121 SASS uses incompatible fat binary format, and LD_PRELOAD can't intercept internal kernel selection. See "Investigation" section above.

**Current state:** Cold start is solved for container restarts (3-8s with warm cache). Only cold after image rebuilds (~276s first run to populate cache).

### Phase 2: Future Improvements

3. **Hot code reload during development** — `uvicorn-hmr` or volume-mounted source (already done) + `docker restart` (now only 3-8s with warm cache, acceptable).

4. **Pre-warm cache after image rebuild** — Script that runs the warmup pipeline once after `docker build`, populating the CUDA cache for subsequent starts.

### Phase 3: Evaluate Alternatives (Next Sprint)

5. **Benchmark Parakeet-TDT on DGX Spark** — If cold start and inference speed matter long-term, Parakeet is the strategic choice. NVIDIA is investing heavily in it, and it has native Blackwell support.

6. **Upgrade to CUDA 13.0 base image** — When available, this would enable SM_121 SASS natively, eliminating even first-run JIT. Watch for NGC containers with CUDA 13.0.

---

## Key References

- [CTranslate2 GitHub](https://github.com/OpenNMT/CTranslate2) — source, build instructions
- [CTranslate2 Blackwell cuBLAS fix (#1865)](https://github.com/OpenNMT/CTranslate2/issues/1865) — INT8 failure on SM_120+
- [WhisperX Blackwell compatibility (#1211)](https://github.com/m-bain/whisperX/issues/1211) — CUDA 12.8 + PyTorch 2.7 fix
- [Mekopa whisperx-blackwell](https://github.com/Mekopa/whisperx-blackwell) — architecture spoofing bridge
- [NVIDIA Blackwell Compatibility Guide](https://docs.nvidia.com/cuda/blackwell-compatibility-guide/index.html) — PTX JIT and fat binaries
- [CUDA Fat Binaries and JIT Caching](https://developer.nvidia.com/blog/cuda-pro-tip-understand-fat-binaries-jit-caching/) — CUDA_CACHE_PATH, CUDA_CACHE_MAXSIZE
- [Backend.AI: Is DGX Spark Actually Blackwell?](https://www.backend.ai/blog/2026-02-is-dgx-spark-actually-a-blackwell) — SM_121 vs SM_100 deep dive
- [DGX Spark SM_121 Software Support Lacking](https://forums.developer.nvidia.com/t/dgx-spark-sm121-software-support-is-severely-lacking-official-roadmap-needed/357663)
- [Running Parakeet on Spark](https://forums.developer.nvidia.com/t/running-parakeet-speech-to-text-on-spark/356353) — community benchmarks
- [NVIDIA Riva ASR NIM Support Matrix](https://docs.nvidia.com/nim/riva/asr/latest/support-matrix.html) — Parakeet on DGX Spark
- [Dropbox: Faster Whisper with torch.compile](https://dropbox.github.io/whisper-static-cache-blog/) — 4.5x speedup
- [PyTorch torch.compile Caching Tutorial](https://docs.pytorch.org/tutorials/recipes/torch_compile_caching_tutorial.html) — portable cache artifacts
- [whisper.cpp](https://github.com/ggml-org/whisper.cpp) — C++ Whisper with CUDA
- [whisper_trt](https://github.com/NVIDIA-AI-IOT/whisper_trt) — TensorRT Whisper
- [Best Open Source STT 2026 Benchmarks](https://northflank.com/blog/best-open-source-speech-to-text-stt-model-in-2026-benchmarks)
- [NVIDIA Parakeet TDT 0.6B v3](https://huggingface.co/nvidia/parakeet-tdt-0.6b-v3) — multilingual ASR
- [TensorRT for DGX Spark](https://forums.developer.nvidia.com/t/tensorrt-for-blackwell-dgx-spark/355279) — beta status
- [vLLM SM_121 Feature Request](https://github.com/vllm-project/vllm/issues/31128) — community tracking
- [natolambert DGX Spark Setup Guide](https://github.com/natolambert/dgx-spark-setup) — ML training on GB10

---

## Answers to Specific Questions

### Q1: What exactly causes the 150s first-inference delay?
**PTX-to-SASS JIT compilation.** CTranslate2 wheels ship PTX (intermediate GPU code) alongside SASS only for older architectures; none include native SASS for SM_121. The CUDA driver must compile the PTX to SM_121 SASS at runtime. Hundreds of unique kernels across Whisper, pyannote, and wav2vec2 accumulate to ~150s.

### Q2: Is it cuBLAS algorithm selection? cuDNN autotuning? Something else?
**Primarily PTX JIT.** cuBLAS and cuDNN contribute but secondarily — their internal kernels also ship as PTX and need JIT compilation. Algorithm selection/autotuning adds milliseconds, not minutes.

### Q3: Can CTranslate2's algorithm selection be cached/persisted to disk?
CTranslate2 itself has no algorithm cache. The underlying CUDA JIT cache (`CUDA_CACHE_PATH`) is the mechanism — and it works excellently when `CUDA_FORCE_PTX_JIT=0` is set. With the fix: 276s → 3-8s on restart.

### Q4: Does faster-whisper or CTranslate2 have any config for persisting computation plans?
**No.** CTranslate2 has `CT2_CUDA_CACHING_ALLOCATOR_CONFIG` for memory allocation tuning but nothing for computation plan caching. The only caching is CUDA's driver-level PTX cache.

### Q5: Would converting to TensorRT help?
**Yes, but SM_121 support is beta.** TensorRT engines are pre-compiled for the target GPU — no JIT needed. But TensorRT on DGX Spark has known issues (missing tactics, fallback warnings). Not recommended until NVIDIA stabilizes SM_121 support.

### Q6: Alternative Whisper implementations?
- **whisper.cpp + CUDA:** Compiles kernels at build time. Best cold-start story. No diarization built in.
- **PyTorch Whisper + torch.compile:** 4.5x faster, portable cache. Requires PyTorch 2.7+.
- **Parakeet-TDT (NeMo):** 50x faster, native Blackwell. English focus. Best long-term choice.
- **NVIDIA Riva NIM:** Production-grade, DGX Spark supported (Parakeet 1.1B CTC/RNNT only).

### Q7: Keep model "warm" across code reloads?
**Yes — uvicorn-hmr.** Hot module reloading without process restart. Main process (and GPU models) stay alive. Only changed Python modules are reloaded. Alternative: separate model server process.

### Q8: Has anyone else reported this on Blackwell/SM_121?
**Yes, widely.** The [DGX Spark forums](https://forums.developer.nvidia.com/t/dgx-spark-sm121-software-support-is-severely-lacking-official-roadmap-needed/357663) have extensive reports. TensorFlow users report [30+ minutes of JIT compilation](https://github.com/tensorflow/tensorflow/issues/89272). The mekopa/whisperx-blackwell project was created specifically to work around these issues.

### Q9: Does the NGC "may not yet be supported" warning mean SM_121 runs in compatibility mode?
**Yes.** SM_121 runs code built for SM_90 (Hopper) via PTX forward compatibility. The whisperx-blackwell container patches `get_device_capability()` to return (9, 0) so libraries accept it. This is functional, but it means **every CUDA kernel is JIT-compiled from SM_90 PTX to SM_121 SASS** on first use.
