# Research: Why PTQ Fails and QAT Wins for NVFP4 Behavioral Preservation

**Date:** 2026-03-15
**Context:** Sessions 337-340, NVFP4 quantization of Opus-Distilled 9B for Annie Voice on DGX Spark
**Key Discovery:** QAT (Quantization-Aware Training) recovers behavioral quality that PTQ cannot

---

## The Problem We Set Out to Solve

Annie Voice runs on a Qwen3.5-9B model fine-tuned with Claude Opus 4.6 reasoning traces
(Jackrong/Qwen3.5-9B-Claude-4.6-Opus-Reasoning-Distilled). In Q4_K_M format on llama-server,
it works perfectly: concise responses, proper tool calling, no markdown, no thinking leaks.

NVFP4 quantization promises transformative speed:
- **TTFT: 90ms constant** (vs Q4_K_M's 300-1600ms growing with context)
- **Decode: 33 tok/s** (matches Q4_K_M)
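- **DeltaNet linear attention**: the DeltaNet layers keep a fixed-size recurrent state, so their per-token cost does not grow with context (linear rather than quadratic prefill)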

But naive NVFP4 quantization (PTQ) destroys the behavioral fine-tuning:
- Tool calling: 0/10 (emits intent as text, can't produce JSON format)
- Markdown suppression: 3/15 pass (80% leak rate)
- Thinking leak: 5/25 (20%; `<think>` blocks leak into the response content)
- Factual accuracy: 5/5 (preserved perfectly)
- Reasoning: 5/5 (preserved perfectly)

**FP4 preserves knowledge but destroys manners.**

---

## What We Tried (Chronological)

### Attempt 1: Better Calibration Data (PTQ)

**Hypothesis:** The v1 model used CNN DailyMail (news articles) for calibration. Annie-specific
calibration data should preserve the activation ranges that matter for tool calling and formatting.

**What we did:**
1. Generated 430 Annie conversations using Claude Opus 4.6 as teacher via Claude Code CLI
   - 100 greetings, 100 factual, 100 search_memory tool calls, 100 web_search tool calls, 30 multi-turn
   - 200/430 (47%) include proper `tool_calls` with correct function names and arguments
   - Zero markdown contamination
2. Re-quantized with NVFP4_DEFAULT_CFG + Annie calibration data (2.7 min)
3. Served with vLLM Docker cu130-nightly

**Result:** Identical failures to v1. Thinking leak, broken tool calling, verbose responses.
The calibration data domain had ZERO effect on quality.

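**Why it failed:** DEFAULT_CFG uses the "max" calibration algorithm, which records the absolute
maximum (amax) of each quantized tensor. Calibration data only determines the *activation* amax
values. NVFP4 weight scales are derived from the weights themselves (one FP8 scale per block of 16
values, plus a per-tensor scale), so weights are rounded the same way no matter which prompts are
used. The behavioral patterns are small weight adjustments that fall within the same blocks and
ranges as the knowledge weights, so they are rounded identically.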
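
A simplified NVFP4 weight quantizer makes the point concrete. This is a sketch, not ModelOpt's implementation: it keeps the per-block scale in full precision (real NVFP4 stores it as FP8 E4M3 and also applies a per-tensor FP32 scale) and rounds to the nearest value on the FP4 E2M1 grid. Note that the function takes only weights; calibration data never enters. The block values are hypothetical: a 16-value "knowledge" block, and the same block after a small fine-tuning shift of +0.01 on every entry except the largest.

```python
# Magnitudes representable in FP4 E2M1 (the sign bit is stored separately).
FP4_E2M1_GRID = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]


def quantize_block(block):
    """Fake-quantize one block of 16 weights onto the FP4 grid.

    Simplified: one full-precision scale per block (amax / 6); real NVFP4
    stores this scale in FP8 E4M3 and also applies a per-tensor FP32 scale.
    """
    assert len(block) == 16, "NVFP4 uses 16-element blocks"
    amax = max(abs(v) for v in block)
    scale = amax / 6.0 if amax > 0 else 1.0
    out = []
    for v in block:
        mag = min(FP4_E2M1_GRID, key=lambda g: abs(g - abs(v) / scale))
        out.append(mag * scale if v >= 0 else -mag * scale)
    return out


# Hypothetical "knowledge" weights, and the same block after a small behavioral delta.
base = [0.6, -0.4, 0.3, 0.2, -0.15, 0.1, 0.05, 0.0,
        -0.3, 0.4, -0.2, 0.15, -0.1, -0.05, 0.3, -0.6]
tuned = [base[0]] + [v + 0.01 for v in base[1:]]

print(quantize_block(base) == quantize_block(tuned))  # True: the delta is rounded away
```

Every entry of `tuned` lands on the same grid point as in `base`; only shifts comparable to the local grid spacing (for example +0.06 on the zero entry) change the rounded value.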

### Attempt 2: AWQ-Lite Algorithm (PTQ)

**Hypothesis:** AWQ-Lite (a lightweight variant of AWQ, Activation-aware Weight Quantization) uses
activation magnitudes to search for per-channel scaling factors that minimize quantization error.
This should better preserve the fine-tuning signal.

**What we did:**
- Quantized with NVFP4_AWQ_LITE_CFG + Annie calibration data (62.7 min — 23x slower)

**Result:** Produced `MIXED_PRECISION` format in hf_quant_config.json. vLLM cu130-nightly
returns "Unknown ModelOpt quant algo: MIXED_PRECISION". Cannot serve.

**The model exists but can't be tested.** Located at `~/models/Qwen3.5-9B-Opus-Distilled-NVFP4-v2a/`.
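
For reference, the AWQ idea from the hypothesis above can be sketched in plain Python. This is a simplified illustration, not ModelOpt's `awq_lite`: a symmetric 4-bit integer quantizer with one scale per output row stands in for NVFP4, and the alpha grid and toy matrices are assumptions. Each input channel j is scaled by s_j = mean|x_j|^alpha; weights are multiplied by s before quantization and activations divided by it, so the full-precision output is unchanged and only the rounding error moves. Alpha = 0 reproduces plain round-to-nearest, so on the calibration inputs the search can never do worse than plain rounding.

```python
def quantize_row(row, levels=7):
    """Symmetric 4-bit round-to-nearest with one scale per row (stand-in for NVFP4)."""
    amax = max(abs(v) for v in row) or 1.0
    scale = amax / levels
    return [round(v / scale) * scale for v in row]


def matvec(W, x):
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]


def output_error(W, X, s):
    """Mean squared error between W @ x and Q(W * s) @ (x / s) over calibration inputs X."""
    Wq = [quantize_row([w * sj for w, sj in zip(row, s)]) for row in W]
    err, n = 0.0, 0
    for x in X:
        ref = matvec(W, x)
        out = matvec(Wq, [xi / sj for xi, sj in zip(x, s)])
        err += sum((a - b) ** 2 for a, b in zip(ref, out))
        n += len(ref)
    return err / n


def awq_search(W, X, grid=11):
    """Grid-search alpha in [0, 1] with s_j = mean|x_j| ** alpha; returns (alpha, error)."""
    mean_abs = [sum(abs(x[j]) for x in X) / len(X) for j in range(len(X[0]))]
    best = None
    for i in range(grid):
        alpha = i / (grid - 1)
        s = [max(m, 1e-8) ** alpha for m in mean_abs]
        e = output_error(W, X, s)
        if best is None or e < best[1]:
            best = (alpha, e)
    return best


# Toy layer: 2 outputs x 4 input channels; channel 1 carries large activations.
W = [[0.9, -0.05, 0.3, 0.02], [-0.4, 0.07, 0.6, -0.01]]
X = [[0.1, 8.0, 0.2, 0.1], [-0.2, 6.0, 0.1, -0.3], [0.3, -7.0, -0.2, 0.2]]
alpha, err = awq_search(W, X)
print(f"alpha={alpha:.1f}  mse={err:.2e}  rtn_mse={output_error(W, X, [1.0] * 4):.2e}")
```

Like the "max" algorithm, this only redistributes rounding error; it never trains the weights, which is consistent with it being a PTQ method.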

### Attempt 3: Prompt Engineering on v1 (No requantization)

**Hypothesis:** If the behavioral signal is partially present, a stronger system prompt might
activate it.

**What we did:** Tested 7 hardened prompts on v1 NVFP4 with reasoning parser enabled.

**Result:** 43% reliability (3/7 pass). Three failure modes:
1. Inconsistent thinking format — parser catches `<think>` tags 57% of the time
2. Tool call format broken — model knows the right tool but emits XML text, not API format
3. When thinking leaks, responses fill with markdown and reasoning consumes the entire token budget

**Diagnostic value:** The signal IS partially there (~40-60%) but unreliable.

---

## The Discovery: PTQ vs QAT

### What is PTQ (Post-Training Quantization)?

PTQ quantizes a model AFTER training is complete:
1. Load trained model (BF16/FP16)
2. Run calibration data through the model to measure activation ranges
3. Compute quantization scales based on observed ranges
4. Clip and round weights to FP4 format
5. Export

**PTQ never optimizes weights.** It only measures ranges and rounds. The behavioral patterns
(small, precise weight adjustments from fine-tuning) are collateral damage.

### What is QAT (Quantization-Aware Training)?

QAT trains the model WITH quantization active:
1. Load trained model (BF16/FP16)
2. Insert fake quantizers (simulated FP4 in forward pass)
3. Fine-tune with the quantizers active:
   - Forward: weights are quantized (model "sees" FP4 precision)
   - Backward: straight-through estimator passes gradients to full-precision weights
   - Weights ADAPT to produce correct output despite quantization noise
4. Export quantized model

**QAT modifies weights to work well under quantization.** The model learns which weight values
matter most and adjusts them to survive FP4 rounding.
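
The forward/backward split in step 3 can be illustrated with a toy one-weight model in plain Python (not ModelOpt code). Assumptions for readability: an integer rounding grid stands in for FP4, and the model is y = q(w) * x with a target weight of 1.0. The gradient is computed at the quantized value and applied to a full-precision latent weight (the straight-through estimator), which lets small updates accumulate until the rounded value flips. Re-rounding after every update, with no latent copy, discards each sub-step update.

```python
def fake_quant(w, step=1.0):
    """Forward-pass quantizer: snap w to the nearest multiple of `step`."""
    return step * round(w / step)


def grad_at_quantized(q, xs, ys):
    """d(MSE)/dq for the model y_hat = q * x, evaluated at the quantized weight q."""
    return sum(2 * (q * x - y) * x for x, y in zip(xs, ys)) / len(xs)


def train(w, xs, ys, lr=0.02, steps=10, keep_latent=True):
    for _ in range(steps):
        q = fake_quant(w)                  # forward pass sees the quantized weight
        g = grad_at_quantized(q, xs, ys)   # STE: treat dq/dw as 1
        if keep_latent:
            w = w - lr * g                 # QAT: update the full-precision latent weight
        else:
            w = fake_quant(q - lr * g)     # no latent copy: each update is rounded away
    return w


xs, ys = [1.0, 2.0], [1.0, 2.0]            # data generated by y = 1.0 * x
start = 0.23                               # latent weight; fake_quant(0.23) == 0.0

qat_w = train(start, xs, ys, keep_latent=True)
naive_w = train(start, xs, ys, keep_latent=False)
print(fake_quant(qat_w), fake_quant(naive_w))  # 1.0 0.0
```

Real QAT operates on millions of weights with FP4 block scales and an optimizer such as AdamW, but the latent-weight mechanism is the one described in step 3.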

### Evidence: PTQ vs QAT Accuracy

**From NVIDIA's Nemotron Super benchmarks (NVFP4):**

| Benchmark | BF16 Baseline | PTQ | QAT/QAD | Gap Recovered |
|-----------|---------------|-----|---------|---------------|
| Math-500 | 0.96 | 0.90 | **0.96** | 100% |
| AIME 2024 | 0.58 | 0.36 | **0.58** | 100% |
| GPQA Diamond | 0.64 | 0.60 | **0.64** | 100% |

Source: [NVIDIA QAT Blog](https://developer.nvidia.com/blog/how-quantization-aware-training-enables-low-precision-accuracy-recovery/)

**From LMSYS MXFP4 research (behavioral alignment):**

| Task | PTQ | QAT | Gap |
|------|-----|-----|-----|
| Safety alignment (FalseReject) | 59% | **97%** | +38 points |

Source: [LMSYS QAT Blog](https://lmsys.org/blog/2025-08-28-gpt-oss-qat/)

**Key insight from NVIDIA:** "QAT fine-tuning for even less than 1% of the original pre-training
time is often sufficient to restore the model's quality."

### Why This Matters for Annie

The Opus-Distilled model was fine-tuned on ~3,950 samples. NVIDIA's 1% figure refers to pre-training
time, so applying it to the fine-tuning set is only a rough analogy: 1% of ~3,950 samples is ~40,
and we have 430 Annie conversations, roughly 10x that. **Data volume is unlikely to be the QAT bottleneck.**

The behavioral patterns we need to preserve (tool calling format, markdown suppression,
conciseness, thinking control) are exactly the kind of "alignment" patterns that PTQ destroys
and QAT recovers. The LMSYS safety alignment result (59% → 97%) is the closest published analogue
to our tool calling result (0% → ???).

---

## How the Original Model Was Fine-Tuned

Source: [Jackrong/Qwen3.5-9B-Claude-4.6-Opus-Reasoning-Distilled](https://huggingface.co/Jackrong/Qwen3.5-9B-Claude-4.6-Opus-Reasoning-Distilled)

| Aspect | Detail |
|--------|--------|
| Base model | Qwen/Qwen3.5-9B |
| Framework | Unsloth (memory/compute optimization) |
| Method | SFT + LoRA with `train_on_responses_only` |
| Loss masking | Instructions masked; loss on `<think>` + answer tokens only |
| Context window | 16,384 tokens |
| Training loss | 0.5138 → 0.35786 (healthy convergence) |
| LoRA details | Rank, alpha, target modules NOT published |
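
`train_on_responses_only` means prompt tokens carry no loss. If the Annie QAT data should follow the same convention, the masking can be sketched in plain Python (token IDs are hypothetical, and this is not Unsloth's implementation): non-assistant tokens get the label -100, the default `ignore_index` of PyTorch's `CrossEntropyLoss`. Hugging Face causal-LM models shift labels internally, so labels stay aligned position-for-position with `input_ids`.

```python
IGNORE_INDEX = -100  # default ignore_index of torch.nn.CrossEntropyLoss


def build_example(segments):
    """segments: list of (role, token_ids). Loss is kept only on assistant tokens."""
    input_ids, labels = [], []
    for role, ids in segments:
        input_ids.extend(ids)
        labels.extend(ids if role == "assistant" else [IGNORE_INDEX] * len(ids))
    return {"input_ids": input_ids, "labels": labels}


example = build_example([
    ("system", [101, 102]),
    ("user", [201, 202, 203]),
    ("assistant", [301, 302, 303]),  # <think> ... </think> + answer tokens
])
print(example["labels"])  # [-100, -100, -100, -100, -100, 301, 302, 303]
```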

**Datasets (3 sources, ~3,950 samples total):**

| Dataset | Purpose |
|---------|---------|
| [nohurry/Opus-4.6-Reasoning-3000x-filtered](https://huggingface.co/datasets/nohurry/Opus-4.6-Reasoning-3000x-filtered) | Claude 4.6 Opus reasoning trajectories |
| [TeichAI/claude-4.5-opus-high-reasoning-250x](https://huggingface.co/datasets/TeichAI/claude-4.5-opus-high-reasoning-250x) | High-intensity structured reasoning |
| [Jackrong/Qwen3.5-reasoning-700x](https://huggingface.co/datasets/Jackrong/Qwen3.5-reasoning-700x) | Curated reasoning for diversity |

**Output format:** `<think> {internal reasoning} </think>\n {final answer}`
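
For the thinking-leak benchmark it helps to split output in this format mechanically. A minimal sketch with hypothetical helper names (not vLLM's reasoning parser, and not necessarily what `benchmark_quant_v3.py` does): split at the first `</think>`, tolerate a missing opening tag in case the chat template emits it as part of the prompt, and count any tag left in the answer as a leak.

```python
def split_reasoning(text):
    """Split '<think> reasoning </think>\n answer' into (reasoning, answer)."""
    if "</think>" not in text:
        return "", text.strip()
    head, _, tail = text.partition("</think>")
    return head.replace("<think>", "", 1).strip(), tail.strip()


def leaks_thinking(answer):
    """True if any think tag survives in the user-visible answer."""
    return "<think>" in answer or "</think>" in answer


print(split_reasoning("<think> check memory </think>\n Dentist at 3pm."))
# ('check memory', 'Dentist at 3pm.')
print(leaks_thinking(split_reasoning("Sure! <think> let me plan")[1]))  # True
```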

**Implication for QAT:** The LoRA target modules are not published, but if the adapters covered the
usual attention and MLP projections, the behavioral patterns are spread across every layer rather
than concentrated in a few. That is why layer exclusions in PTQ can't selectively protect them: you'd
need to exclude everything. QAT addresses this by letting training compensate for quantization error
in every quantized layer.

---

## ModelOpt QAT Recipe (from NVIDIA examples)

**Repository:** [NVIDIA/Model-Optimizer/examples/llm_qat](https://github.com/NVIDIA/Model-Optimizer/blob/main/examples/llm_qat/README.md)

### QAT with LoRA + NVFP4 (recommended for our use case)

```bash
# From inside NGC container:
./launch.sh \
  --model /models/Qwen3.5-9B-Opus-Distilled-BF16 \
  --num_epochs 0.5 \
  --lr 1e-3 \
  --do_train True \
  --output_dir /models/Qwen3.5-9B-Opus-Distilled-NVFP4-QAT \
  --quant_cfg NVFP4_DEFAULT_CFG \
  --compress True \
  --lora True
```

### Key Parameters

| Parameter | Recommended | Notes |
|-----------|-------------|-------|
| Learning rate | 1e-3 to 1e-5 | Lower = safer, higher = faster convergence |
| Epochs | 0.5-3 | "Less than 1% of pre-training time" |
| LoRA | Yes | Memory efficient; reduces catastrophic forgetting |
| quant_cfg | NVFP4_DEFAULT_CFG | Produces vLLM-compatible format |
| compress | True | Exports real FP4 (not fake quantizers) |

### Custom QAT Script (for our Annie data)

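```python
import torch
import modelopt.torch.quantization as mtq
from modelopt.torch.export import export_hf_checkpoint
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments, Trainer

# Paths follow the launch.sh example above; adjust to your layout.
model_path = "/models/Qwen3.5-9B-Opus-Distilled-BF16"
output_path = "/models/Qwen3.5-9B-Opus-Distilled-NVFP4-QAT"

# Hypothetical placeholders: build these from the Annie JSONL before running.
calib_data = ...     # iterable of dicts holding "input_ids" tensors (small calibration subset)
annie_dataset = ...  # tokenized Annie conversations with "input_ids" and "labels"

# 1. Load model in BF16 directly on the GPU
model = AutoModelForCausalLM.from_pretrained(
    model_path, torch_dtype=torch.bfloat16, device_map="cuda"
)
tokenizer = AutoTokenizer.from_pretrained(model_path)

# 2. Insert NVFP4 fake quantizers and calibrate their initial ranges (same call as PTQ)
def forward_loop(model):
    with torch.no_grad():
        for batch in calib_data:
            model(batch["input_ids"].cuda())

model = mtq.quantize(model, mtq.NVFP4_DEFAULT_CFG, forward_loop)
# NOTE: Do NOT call model.eval(); the quantized model must stay trainable.
# This is full-parameter QAT; the launch.sh recipe above uses LoRA instead.

# 3. QAT fine-tune on Annie conversations (quantizers stay active in the forward pass)
training_args = TrainingArguments(
    output_dir="./qat_output",
    num_train_epochs=1,
    learning_rate=1e-5,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=8,
    bf16=True,
    logging_steps=10,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=annie_dataset,
    processing_class=tokenizer,  # replaces the deprecated `tokenizer=` argument
)
trainer.train()

# 4. Export the quantized checkpoint in Hugging Face format
export_hf_checkpoint(model, export_dir=output_path)
```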

---

## DGX Spark Serving Challenges (SM 12.1)

### What Doesn't Work
- pip-installed vLLM/SGLang (no SM 12.1 CUDA kernels)
- Docker images with CUDA <13.0
- CUDA graphs (`--enforce-eager` mandatory)
- `MIXED_PRECISION` quant format in vLLM
- transformers 5.3 tokenizer (`TokenizersBackend` class unknown to older transformers)
- `qwen3_5_text` model_type (unknown to transformers 4.57 in Docker)

### What Works
- Docker `vllm/vllm-openai:cu130-nightly` (CUDA 13.0 native SM 12.1 support)
- `--quantization modelopt_fp4` (not `modelopt`)
- VL config wrapper (`model_type: "qwen3_5"`, `Qwen3_5ForConditionalGeneration`)
- Tokenizer from working model (copy from 27B or v1)
- `preprocessor_config.json` stub (copy from 27B)
- `--language-model-only` (skips vision encoder init)

---

## CLI Learnings (Claude Code for Data Generation)

### What Works
- `subprocess.run()` with temp file for prompt (not stdin pipe)
- `--append-system-prompt "You are a data generator. Output ONLY the requested data format..."` (prevents CLI from describing the task)
- `--model opus` for highest quality calibration data
- 10 conversations per batch (50 returns empty)
- Running from `/tmp` to avoid project hooks

### What Doesn't Work
- `--no-session-persistence` with Opus (causes indefinite hangs)
- Shell pipes with subshells `echo | (cd /tmp && claude -p ...)` (output swallowed)
- `--max-turns 1` (flag doesn't exist)
- Batches of 50+ conversations (CLI output size limit)
- Multi-turn with <120s timeout (11-message conversations need more time)
- Global stop hooks with `-p` mode (hook output contaminates model output)
- **Long prompts with JSON/code-like content** — Claude enters "planning mode" and outputs "Plan is ready for your review" instead of data. Keep generation prompts <3000 chars, use condensed tool lists, not full JSON schemas. Add explicit anti-planning instructions to `--append-system-prompt`.
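
A sketch of a batch-generation wrapper that applies these lessons. Assumptions: the prompt is written to a temp file and handed to `claude -p` as stdin (our reading of "temp file for prompt"); the `--append-system-prompt` text is truncated in these notes, so it stays truncated here; the helper names and the 180 s default timeout are hypothetical.

```python
import subprocess
import tempfile

# Truncated in the notes above; paste the full wording plus anti-planning instructions.
APPEND_PROMPT = "You are a data generator. Output ONLY the requested data format..."
MAX_PROMPT_CHARS = 3000  # longer prompts tended to trigger "planning mode"


def build_command(model="opus"):
    return ["claude", "-p", "--model", model, "--append-system-prompt", APPEND_PROMPT]


def batches(items, size=10):
    """10 conversations per call worked; 50 came back empty."""
    return [items[i:i + size] for i in range(0, len(items), size)]


def generate(prompt, timeout_s=180):
    """One CLI call, run from /tmp to avoid project hooks; multi-turn needs more than 120 s."""
    if len(prompt) > MAX_PROMPT_CHARS:
        raise ValueError(f"prompt is {len(prompt)} chars; keep it under {MAX_PROMPT_CHARS}")
    with tempfile.NamedTemporaryFile("w+", suffix=".txt") as f:
        f.write(prompt)
        f.flush()
        f.seek(0)
        result = subprocess.run(
            build_command(), stdin=f, cwd="/tmp",
            capture_output=True, text=True, timeout=timeout_s, check=True,
        )
    return result.stdout
```

`batches(list(range(25)))` yields chunks of 10, 10 and 5; `generate` needs the CLI installed and is not exercised here.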

---

## Quantization Timing Comparison

| Strategy | Algorithm | Time | Output Format | vLLM Compatible |
|----------|-----------|------|---------------|----------------|
| DEFAULT_CFG (v1) | max | 3.8 min | NVFP4 | Yes |
| DEFAULT_CFG (v2d) | max | 2.7 min | NVFP4 | Yes |
| AWQ_LITE (v2a) | awq_lite | 62.7 min | MIXED_PRECISION | No |
| QAT (next) | max + fine-tune | ~30-60 min est. | NVFP4 | Yes (expected) |

---

## Next Steps: QAT Implementation Plan

### Phase 1: Adapt ModelOpt QAT Example
1. Clone `NVIDIA/Model-Optimizer/examples/llm_qat/`
2. Modify `launch.sh` for Qwen3.5-9B-Opus-Distilled
3. Point to our Annie calibration JSONL (430 conversations)
4. Run inside NGC container on Titan

### Phase 2: QAT Fine-Tuning
1. Insert NVFP4 quantizers with DEFAULT_CFG
2. LoRA fine-tune on Annie data (LR 1e-5, 0.5-1 epoch)
3. Export with `--compress True`

### Phase 3: Serve and Benchmark
1. Patch config.json (VL wrapper)
2. Copy tokenizer + preprocessor from working model
3. Serve with Docker cu130-nightly + `--quantization modelopt_fp4`
4. Run benchmark_quant_v3.py

### Quality Gates (must beat v1 PTQ)
| Gate | v1 PTQ | Target |
|------|--------|--------|
| Thinking leak | 5/25 (20%) | 0/25 (0%) |
| Tool calling | 0/10 (0%) | ≥6/10 (60%) |
| No markdown | 3/15 (20%) | ≥12/15 (80%) |
| TTFT | 90ms | ≤120ms |
| Decode | 33 tok/s | ≥25 tok/s |
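
The gates can be encoded as a small check so every benchmark run reports pass/fail the same way. A sketch with hypothetical field names (the output format of `benchmark_quant_v3.py` isn't shown here); the thresholds are copied from the table above.

```python
from dataclasses import dataclass


@dataclass
class BenchResult:
    think_leaks: int       # leaks out of 25 prompts
    tool_calls_ok: int     # correct tool calls out of 10
    no_markdown_ok: int    # markdown-free replies out of 15
    ttft_ms: float
    decode_tok_s: float


def failed_gates(r):
    """Return the names of the quality gates `r` misses (empty list: every gate passes)."""
    checks = {
        "thinking_leak": r.think_leaks == 0,
        "tool_calling": r.tool_calls_ok >= 6,
        "no_markdown": r.no_markdown_ok >= 12,
        "ttft": r.ttft_ms <= 120,
        "decode": r.decode_tok_s >= 25,
    }
    return [name for name, ok in checks.items() if not ok]


v1_ptq = BenchResult(think_leaks=5, tool_calls_ok=0, no_markdown_ok=3, ttft_ms=90, decode_tok_s=33)
print(failed_gates(v1_ptq))  # ['thinking_leak', 'tool_calling', 'no_markdown']
```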

---

## Sources

- [Jackrong/Qwen3.5-9B-Claude-4.6-Opus-Reasoning-Distilled (HuggingFace)](https://huggingface.co/Jackrong/Qwen3.5-9B-Claude-4.6-Opus-Reasoning-Distilled)
- [NVIDIA QAT Blog: How QAT Enables Low-Precision Accuracy Recovery](https://developer.nvidia.com/blog/how-quantization-aware-training-enables-low-precision-accuracy-recovery/)
- [NVIDIA QAD Report (PDF)](https://research.nvidia.com/labs/nemotron/files/NVFP4-QAD-Report.pdf)
- [LMSYS: Fine-tune gpt-oss MXFP4 with QAT](https://lmsys.org/blog/2025-08-28-gpt-oss-qat/)
- [NVIDIA Introducing NVFP4 Blog](https://developer.nvidia.com/blog/introducing-nvfp4-for-efficient-and-accurate-low-precision-inference/)
- [NVIDIA Model-Optimizer QAT Examples](https://github.com/NVIDIA/Model-Optimizer/blob/main/examples/llm_qat/README.md)
- [NVIDIA NVFP4 on DGX Spark](https://build.nvidia.com/spark/nvfp4-quantization)
- [Red Hat: Accelerating LLMs with NVFP4](https://developers.redhat.com/articles/2026/02/04/accelerating-large-language-models-nvfp4-quantization)
