# Nemotron 3 Nano 30B QAT Training — Research & Pipeline

**Date:** 2026-03-17 (Session 350)
**Status:** Smoke test phase — mamba-ssm blocker identified, Unsloth path validated
**Goal:** Train Nemotron 3 Nano to be Annie (behavioral fine-tuning) → single model for voice + extraction

## 1. Why Train Nemotron?

Currently running 2 vLLM containers:
- **Qwen3.5-9B QAT-v4** (port 8003): Annie Voice — behavioral fine-tuned, 10/10 pass
- **Nemotron 3 Nano 30B NVFP4** (port 8004): Extraction — 4x faster, generic personality

If we QAT-train Nemotron with Annie's persona, we can run **ONE model for everything**:
- Save 19 GB VRAM (no more Qwen 9B)
- Simpler architecture (1 container vs 2)
- 4x faster extraction AND voice

## 2. Training Data Available

| Dataset | Conversations | Categories | Notes |
|---------|--------------|-----------|-------|
| **v3** (calibration_v3/) | 2,000 | 7 (topic_switch, direct_action, concise_voice, honest_memory, rajesh_context, kannada_culture, mixed_multiturn) | Full production system prompt |
| **v4** (calibration_v4/) | 1,998 | 9 (v3 + no_emoji_no_markdown, tool_result_responses, no_followup_questions) | Adversarial, current `<rules>` block |
| **Total** | **3,998** | 9 adversarial categories | Generated by Opus 4.6 teacher |

All conversations are in JSONL format: `{"messages": [{"role": "system", ...}, {"role": "user", ...}, {"role": "assistant", ...}]}`
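A minimal schema check for this shape might look like the following sketch (the `tool` role is an assumption, included because the v4 set contains tool-result conversations):

```python
import json

ALLOWED_ROLES = {"system", "user", "assistant", "tool"}

def validate_conversation(line: str) -> bool:
    """Return True if a JSONL line matches the expected messages schema."""
    try:
        record = json.loads(line)
    except json.JSONDecodeError:
        return False
    messages = record.get("messages")
    if not isinstance(messages, list) or not messages:
        return False
    return all(
        isinstance(m, dict) and m.get("role") in ALLOWED_ROLES and "content" in m
        for m in messages
    )
```

Running this over the merged training file before a run catches malformed lines cheaply.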

## 3. Training Approaches Evaluated

### 3a. Full QAT with ModelOpt (our Qwen 9B approach)

```
Load BF16 (63 GB) → mtq.quantize(NVFP4_DEFAULT_CFG) → Train → Export
```

| Pro | Con |
|-----|-----|
| Proven on Qwen 9B (v1-v4) | BF16 model = 63 GB download |
| modelopt handles quantization | Training peak ~95-110 GB with LoRA |
| Export directly to NVFP4 serving format | `mamba-ssm` package BLOCKED on aarch64 |
| Exact recipe in `scripts/qat_nvfp4.py` | NGC PyTorch container lacks dependencies |

**BLOCKER:** `mamba-ssm` cannot be pip-installed on DGX Spark (aarch64 + CUDA 13.0). No pre-built wheels exist. Build from source fails with `torch._C._dlpack_exchange_api` error.

### 3b. Unsloth LoRA Fine-Tuning (RECOMMENDED)

```
Load BF16 via Unsloth (handles mamba-ssm internally) → LoRA → Train → Merge → Export GGUF/NVFP4
```

| Pro | Con |
|-----|-----|
| **Day-zero Nemotron 3 Nano support** | Need to install Unsloth on Titan |
| Handles mamba-ssm dependency internally | BF16 model still 63 GB download |
| ~60 GB VRAM for 16-bit LoRA | Output is GGUF — may need a separate re-quantization step to NVFP4 |
| Router frozen by default (MoE best practice) | Unknown: does Unsloth work on aarch64? |
| Notebook available for Nemotron 30B | |

**Source:** [Unsloth Nemotron 3 Guide](https://docs.unsloth.ai/models/nemotron-3)

### 3c. LoRA SFT on NVFP4 (fastest test, uncertain quality)

```
Load NVFP4 (18 GB, already on Titan) → LoRA → Train → Merge
```

| Pro | Con |
|-----|-----|
| No 63 GB download needed | Gradients through quantized weights = noisy |
| Model already on Titan | Quality may be worse than BF16-based training |
| ~35-45 GB peak VRAM | No precedent for this approach |
| Fastest to test | Export path unclear |

## 4. Smoke Test Results

### Smoke Test 1: Load NVFP4 + modelopt QAT

```
Container: nvcr.io/nvidia/pytorch:25.11-py3
Result: FAILED
Error: ImportError: mamba-ssm is required by the Mamba model
Root cause: mamba-ssm PyPI has no aarch64 wheel, build fails on CUDA 13
```

### Smoke Test 2: Install mamba-ssm from source

```
Result: FAILED
Error: AttributeError: module 'torch._C' has no attribute '_dlpack_exchange_api'
Root cause: PyTorch 2.10.0 in NGC container incompatible with mamba-ssm build
```

### Smoke Test 3: Patch rmsnorm_fn + Direct HF + BitsAndBytes 4-bit + PEFT LoRA

```
Approach: Patch modeling_nemotron_h.py to replace mamba-ssm rmsnorm_fn with pure PyTorch fallback
Container: nvcr.io/nvidia/pytorch:25.11-py3 + pip install peft bitsandbytes transformers
Result: SUCCESS ✓
  - Model loaded: NemotronHForCausalLM, 10.3 GB GPU (4-bit bitsandbytes)
  - LoRA added: 1,867,776 trainable params (0.01% of 16.3B)
  - Target modules: q_proj, k_proj, v_proj, o_proj (attention only, MoE untouched)
```

**Key fix:** `scripts/patch_nemotron_mamba.py` — replaces the hard `mamba-ssm` import
with a pure PyTorch `rmsnorm_fn` implementation (~30 lines). The function is just
`x * rsqrt(mean(x²) + eps) * silu(gate) * weight` — no Triton kernels needed for training.
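That fallback can be sketched as below. The signature is an assumption (mamba-ssm's gated RMSNorm takes hidden states, a learned weight, and an optional SiLU gate `z`); the actual patch script may differ in detail:

```python
import torch
import torch.nn.functional as F

def rmsnorm_fn(x, weight, z=None, eps=1e-5):
    """Pure-PyTorch stand-in for mamba-ssm's gated RMSNorm.

    Implements x * rsqrt(mean(x^2) + eps) * silu(z) * weight from the note above.
    """
    x = x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)
    if z is not None:
        x = x * F.silu(z)  # gated variant used inside Mamba blocks
    return x * weight
```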

## 5. Dependency Matrix (DGX Spark aarch64)

| Package | PyPI Wheel (aarch64) | NGC Container | Build from Source | Status |
|---------|---------------------|---------------|-------------------|--------|
| `torch` | No (needs cu130) | 25.11: 2.10.0 | N/A (use NGC) | OK in container |
| `transformers` | Yes | 25.11: NO, vLLM: 4.57 | N/A | Need pip install |
| `nvidia-modelopt` | ? | 25.11: 0.37.0 | N/A | OK in PyTorch container |
| `mamba-ssm` | **NO** | **NO** | **FAILS** | **BLOCKER** |
| `causal-conv1d` | **NO** | **NO** | Not tested | Likely same issue |
| `peft` | Yes | NO | N/A | pip install OK |
| `unsloth` | ? | NO | ? | **NEEDS TESTING** |

## 6. Recommended Pipeline (Unsloth)

### Step 1: Download BF16 Model (~63 GB, ~30-60 min)
```bash
ssh titan "~/workplace/her/her-os/services/annie-voice/.venv/bin/huggingface-cli download \
    nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16 \
    --local-dir ~/models/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16"
```

### Step 2: Install Unsloth on Titan
```bash
# Option A: pip install (may need aarch64 compatibility check)
pip install unsloth

# Option B: Docker (if Unsloth provides an image)
docker pull unsloth/unsloth:latest

# Option C: Conda (check Unsloth docs for a supported channel/package)
conda install unsloth -c conda-forge
```

### Step 3: Prepare Training Data
```bash
# Merge v3 + v4 data into single training file
cat data/calibration_v3/*.jsonl data/calibration_v4/*.jsonl > /tmp/annie_train_combined.jsonl
# 3,998 conversations
```

### Step 4: LoRA Fine-Tune with Unsloth
```python
from unsloth import FastLanguageModel

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name="nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16",
    max_seq_length=2048,
    load_in_4bit=True,  # QLoRA: 4-bit base + LoRA
)

model = FastLanguageModel.get_peft_model(
    model,
    r=16,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
    lora_alpha=32,
    lora_dropout=0.05,
)

# Train with TRL's SFTTrainer
from transformers import TrainingArguments
from trl import SFTTrainer
trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    train_dataset=dataset,
    max_seq_length=2048,
    args=TrainingArguments(
        num_train_epochs=5,
        per_device_train_batch_size=1,
        gradient_accumulation_steps=8,
        learning_rate=2e-5,
        bf16=True,
        output_dir="./nemotron-annie-qat",
    ),
)
trainer.train()
```

### Step 5: Export and Quantize
```python
# Save LoRA merged model
model.save_pretrained_merged("./nemotron-annie-bf16", tokenizer)

# Then quantize to NVFP4 using modelopt (in NGC PyTorch container)
# OR export GGUF via Unsloth:
model.save_pretrained_gguf("./nemotron-annie-gguf", tokenizer, quantization_method="q4_k_m")
```

### Step 6: Serve via vLLM
```bash
# If NVFP4 export works:
docker run -d --name vllm-nemotron-annie \
    --gpus all --runtime=nvidia --network host --ipc=host \
    -e VLLM_USE_FLASHINFER_MOE_FP4=1 \
    -e VLLM_FLASHINFER_MOE_BACKEND=latency \
    -v ~/models:/models \
    nvcr.io/nvidia/vllm:25.12.post1-py3 \
    vllm serve /models/Nemotron-3-Nano-Annie-QAT-v1 \
    --served-model-name nemotron-nano \
    --enforce-eager --kv-cache-dtype fp8 \
    --gpu-memory-utilization 0.35

# If GGUF only: use llama-server or Ollama
```

## 7. Existing Scripts (Reusable)

| Script | Purpose | Reuse for Nemotron? |
|--------|---------|-------------------|
| `scripts/qat_nvfp4.py` | QAT training (Qwen 9B) | **Template** — need mamba-ssm fix |
| `scripts/generate_qat_v4_data.py` | Data generation (Opus teacher) | **As-is** — data already generated |
| `scripts/quantize_nvfp4_v2.py` | PTQ quantization | Not needed (model is already NVFP4) |
| `scripts/validate_training_data.py` | Data quality checks | **As-is** |
| `scripts/benchmark_nemotron_vs_qwen27b.py` | Benchmark | **As-is** — verify after training |

## 8. Training Parameters (from Qwen 9B QAT v4)

| Parameter | Qwen 9B (proven) | Nemotron 30B (planned) |
|-----------|-----------------|----------------------|
| Conversations | 999 | 2000-3998 |
| Categories | 9 adversarial | Same 9 categories |
| Epochs | 5 | 3-5 |
| Learning rate | 1e-5 | 2e-5 (LoRA needs higher) |
| Batch size | 1 | 1 |
| Gradient accumulation | 8 | 8 |
| Max sequence length | 1024 | 1024-2048 |
| Training mode | Full FT + modelopt QAT | LoRA (rank 16) via Unsloth |
| Loss | Full sequence | Assistant-only (Unsloth SFTTrainer) |
| Peak VRAM | ~90 GB | ~60 GB (16-bit LoRA) |
| Training time | ~7.2 hours (5 epochs, 999 convos) | ~10-15 hours (estimated) |

## 9. Key Lessons from Qwen QAT v1-v4

1. **FP4 preserves knowledge but destroys manners** — PTQ keeps factual accuracy but breaks formatting, tool calling, personality
2. **QAT = SFT with quantization active** — one training run does both behavioral + quantization
3. **System prompt must match production** — v3 failed because training data had stale prompt (missing `<rules>`)
4. **9 adversarial categories** cover all failure modes: emoji, markdown, verbose, follow-up questions, tool leaks, fabrication, wrong tools, name errors, think leaks
5. **Loss should drop 75%+** — v4 went 1.307→0.329 (successful); v1 only dropped 24% (failed)
6. **Export gotchas**: VL config wrapper, tokenizer version mismatch, preprocessor_config.json stub all needed post-export

### Smoke Test 4: Train on NVFP4 (shape mismatch)

```
Approach: Load NVFP4 model with bitsandbytes NF4 + LoRA, train 10 steps
Result: FAILED at training step 1
Error: RuntimeError: mat1 and mat2 shapes cannot be multiplied (512x2688 and 1344x10304)
Root cause: NVFP4 weight format (interleaved scales) is incompatible with
  bitsandbytes NF4 dequantization. BitsAndBytes expects standard weight shapes.
```

**Conclusion: BF16 model is REQUIRED for training.** The NVFP4 model cannot be used
as a training base — its quantized weight layout is different from what bitsandbytes expects.

Download started: `nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16` (~63 GB)

## 10. QLoRA Training Issues & Fixes (Session 351)

### Issue 1: MoE dtype mismatch (`index_add_` BFloat16 vs Float)

**Error:** `RuntimeError: index_add_(): self (BFloat16) and source (Float) must have the same scalar type`

**Root cause:** BitsAndBytes NF4 quantization dequantizes expert weights to Float32 for computation,
but the MoE routing buffer `final_hidden_states` is BFloat16. The `index_add_` op requires matching types.

**Fix:** Patch `modeling_nemotron_h.py` line ~888:
```python
# Before:
final_hidden_states.index_add_(0, token_indices, weighted_output)
# After:
final_hidden_states.index_add_(0, token_indices, weighted_output.to(final_hidden_states.dtype))
```
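The failure and the cast can be reproduced in isolation (a minimal repro, not the actual MoE code):

```python
import torch

# final_hidden_states buffer is BFloat16; dequantized expert output is Float32
dest = torch.zeros(4, 8, dtype=torch.bfloat16)
src = torch.ones(2, 8, dtype=torch.float32)
idx = torch.tensor([0, 2])

# dest.index_add_(0, idx, src)  # raises: self (BFloat16) and source (Float) must match
dest.index_add_(0, idx, src.to(dest.dtype))  # the one-line fix: cast to the buffer dtype
```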

### Issue 2: Dummy expert `fused_dropout` on Byte tensor

**Error:** `NotImplementedError: "fused_dropout" not implemented for 'Byte'`

**Root cause:** Empty expert fallback path reads `expert.down_proj.weight.dtype` to create dummy input.
With BitsAndBytes NF4, weight dtype is `torch.uint8` (Byte). LoRA dropout then fails on uint8 input.

**Fix:** Use compute dtype instead of storage dtype:
```python
# Before:
expert_dtype = expert.down_proj.weight.dtype
# After:
expert_dtype = hidden_states.dtype  # BitsAndBytes stores as uint8, use compute dtype
```

Also cast dummy output: `final_hidden_states = final_hidden_states + dummy_out.to(final_hidden_states.dtype)`

### Issue 3: Tool call arguments format mismatch

**Error:** `TypeError: Can only get item pairs from a mapping.` in Jinja2 chat template

**Root cause:** Training data stores `tool_calls[].function.arguments` as a JSON **string** (OpenAI format),
but Nemotron's chat template uses `tool_call.arguments|items` which requires a **dict**.

**Fix:** Preprocess in dataset class — `json.loads()` the arguments string before passing to `apply_chat_template`.
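A sketch of that preprocessing step (field names follow the OpenAI tool-call layout; adjust if Nemotron's template reads `arguments` from a different level):

```python
import json

def normalize_tool_calls(messages):
    """Parse string-typed function arguments into dicts, in place."""
    for msg in messages:
        for call in msg.get("tool_calls") or []:
            fn = call.get("function", {})
            if isinstance(fn.get("arguments"), str):
                fn["arguments"] = json.loads(fn["arguments"])
    return messages
```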

### Issue 4: Low unmasked token percentage (~4%)

**Observation:** With `max_length=1024`, system prompt + user messages consume ~979 tokens,
leaving only ~45 tokens for assistant content. This means the model trains on very little
useful signal per sample.

**Mitigation:** Increase `max_length` to 2048 for the full training run. The smoke test showed
64 GB peak at 1024 — at 2048, expect ~70-75 GB peak (still within 128 GB budget).
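The truncation arithmetic can be made explicit (a sketch assuming right-truncation, padding to `max_length`, and loss computed only on assistant tokens):

```python
def unmasked_fraction(prompt_tokens: int, assistant_tokens: int, max_length: int) -> float:
    """Fraction of positions in a padded sample that carry assistant-token loss."""
    kept_total = min(prompt_tokens + assistant_tokens, max_length)
    kept_assistant = kept_total - min(prompt_tokens, max_length)
    return kept_assistant / max_length

# ~979 prompt tokens at max_length=1024 leaves ~45 loss-bearing positions (~4.4%)
```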

### Issue 5: MoE expert LoRA inflates training time 3x

**Observation:** Targeting MLP modules (`gate_proj`, `up_proj`, `down_proj`) in a 64-expert MoE
means LoRA adapters are injected into every expert. Trainable params: 434M (vs 1.8M attention-only).

**Impact:** 33s/step vs 10.3s/step smoke test. Full training estimate revised:
- 999 samples × 3 epochs × 33s/step = ~27.5 hours (up from 8.5h estimate)
- Peak GPU: 64.2 GB (unchanged, since LoRA params are small per-layer)
- Adapter size: 1.68 GB (large due to expert count)

### Issue 6: Docker creates root-owned files in /tmp

**Observation:** Files saved by the NGC container (run with `--rm`) are owned `root:root` and cannot be removed without sudo.

**Mitigation:** Use `--user $(id -u):$(id -g)` in the Docker run command, or save to a user-owned
directory mounted with matching permissions.

### Issue 7: LoRA merge save_pretrained fails

**Error:** `'list' object has no attribute 'keys'` during `merged_model.save_pretrained()`

**Root cause:** `merge_and_unload()` succeeded (12s, no OOM), but Nemotron's custom config
has nested list fields that HuggingFace's `save_pretrained()` serializer can't handle.
Only config files (28K) were saved, no weight safetensors.

**Workaround options:**
1. **vLLM LoRA serving** — `--lora-modules` loads adapter at runtime (no merge needed)
2. **Manual merge** — Save `state_dict` to safetensors directly instead of `save_pretrained`

## 12. Full Training Results (Session 351)

| Metric | Value |
|--------|-------|
| Model | Nemotron 3 Nano 30B (MoE, 3.2B active) |
| Method | QLoRA (NF4 base + LoRA rank 16) |
| Trainable params | 434.7M / 16.58B (2.62%) |
| Data | 999 conversations (v4 adversarial) |
| Epochs | 3 |
| Max length | 2048 |
| Effective batch | 8 (1 × grad_accum 8) |
| Training time | 14.2 hours |
| Peak GPU | 64.1 GB / 128 GB (50%) |
| Speed | ~136s / optimizer step |
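The table's figures are internally consistent, which a few lines of arithmetic confirm:

```python
import math

samples, grad_accum, epochs = 999, 8, 3
steps_per_epoch = math.ceil(samples / grad_accum)  # 125 — matches checkpoint-125
total_steps = steps_per_epoch * epochs             # 375 — matches checkpoint-375
hours = total_steps * 136 / 3600                   # ~14.2 h at ~136 s/optimizer step
```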

### Loss Curve
```
Step   Loss    Epoch   Phase
  5    13.84   0.04    Warmup
 15    16.46   0.12    Peak (warmup end)
 25    12.68   0.20    Post-warmup drop
 55     7.52   0.44    Rapid decline
125     5.28   1.00    Epoch 1 end (checkpoint saved)
190     3.96   1.52    Epoch 2 mid
250     4.27   2.00    Epoch 2 end (checkpoint saved)
280     3.89   2.24    Epoch 3
325     3.26   2.60    BEST LOSS
375     4.29   3.00    Epoch 3 end (checkpoint saved)
```

Overall: 16.46 → 3.26 best (**80% reduction**).

### Saved Artifacts
- **Adapter**: `~/models/Nemotron-3-Nano-Annie-QLoRA-v1/adapter/` (1.66 GB)
- **Checkpoint 125** (epoch 1): `~/models/Nemotron-3-Nano-Annie-QLoRA-v1_checkpoints/checkpoint-125/`
- **Checkpoint 250** (epoch 2): `~/models/Nemotron-3-Nano-Annie-QLoRA-v1_checkpoints/checkpoint-250/`
- **Checkpoint 375** (epoch 3): `~/models/Nemotron-3-Nano-Annie-QLoRA-v1_checkpoints/checkpoint-375/`
- **Merged BF16**: FAILED — config serialization bug (only configs saved, no weights)

### MoE Patches Required (in `modeling_nemotron_h.py`)
1. `index_add_` dtype cast: `weighted_output.to(final_hidden_states.dtype)`
2. Dummy expert compute dtype: `hidden_states.dtype` instead of `expert.down_proj.weight.dtype`
3. Dummy output cast: `dummy_out.to(final_hidden_states.dtype)`

### Next Steps
1. Attempt vLLM serving with `--lora-modules` on base model
2. OR fix manual merge: `torch.save(merged.state_dict())` + safetensors
3. Run behavioral tests: `scripts/test_annie_conversations.py`
4. Compare throughput: `scripts/benchmark_nemotron_vs_qwen27b.py`

### Issue 8: NVFP4 export blocked by save_pretrained config bug

**Status:** PTQ calibration and quantization succeed (18102 quantizers, 256 calib samples, 1146s).
But ALL export paths fail because Nemotron's nested list config breaks HuggingFace serialization.

**What works:**
- `mtq.quantize(model, custom_cfg, forward_loop)` — quantization completes
- `_export_hf_checkpoint()` — returns `(post_state_dict, hf_quant_config)` with valid NVFP4 config
- `hf_quant_config.json` written with correct `quant_algo: "NVFP4"` and exclude list

**What fails:**
- `export_hf_checkpoint()` → calls `model.save_pretrained()` → `'list' object has no attribute 'keys'`
- `_export_hf_checkpoint()` post_state_dict → safetensors save → also fails (exact error TBD)
- Falls to `torch.save(state_dict)` → 19 GB raw PyTorch format, not servable by vLLM

**Root cause:** Nemotron's custom `configuration_nemotron_h.py` has list-type config fields that
HuggingFace's `_get_tied_weight_keys()` can't iterate. This affects `save_pretrained()` in
transformers 5.3.0 with ModelOpt 0.37.0.

**Next steps:**
1. Debug the safetensors save path — the post_state_dict should be saveable directly
2. OR monkey-patch `configuration_nemotron_h.py` to fix the list→dict conversion
3. OR serve the BF16 merged model directly (works, 90% behavioral pass, uses ~55 GB VRAM)
4. OR try NVIDIA's `hf_ptq.py` script from TensorRT-Model-Optimizer which may handle this

## 13. Open Questions

1. **Does Unsloth work on aarch64 DGX Spark?** — No confirmation yet. The Triton kernels may need CUDA 13 support.
2. **BF16 or 4-bit base for LoRA?** — 4-bit (QLoRA) saves memory but may not preserve enough behavioral signal. 16-bit is safer but needs 60 GB.
3. **NVFP4 export from Unsloth?** — Unsloth exports GGUF and HF formats. NVFP4 may need a separate modelopt step.
4. **Will MoE routing be stable after LoRA?** — Router is frozen (Unsloth default), so routing should be unchanged. But attention-only LoRA may not capture all behavioral patterns.
5. **Is Nemotron's chat template compatible with our training data?** — Training data uses `enable_thinking=False`. Need to verify Nemotron's Jinja template handles this.

## Sources

- [Unsloth Nemotron 3 Nano Guide](https://docs.unsloth.ai/models/nemotron-3)
- [Unsloth LoRA Hyperparameters](https://docs.unsloth.ai/get-started/fine-tuning-llms-guide/lora-hyperparameters-guide)
- [Trouble fine-tuning Nemotron (GitHub Discussion)](https://github.com/unslothai/unsloth/discussions/3810)
- [DGX Spark ML Training Setup](https://github.com/natolambert/dgx-spark-setup)
- [DGX Spark CUDA Install Pitfalls](https://forums.developer.nvidia.com/t/dgx-spark-cuda-install-pitfalls-on-ubuntu-24-04-arm64-fixed/349881)
- [NVIDIA Nemotron 3 Nano NVFP4](https://huggingface.co/nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-NVFP4)
- [QAT Cookbook](docs/QAT-COOKBOOK.md) — our proven v1-v4 recipe for Qwen 9B
