API Reference for model-training

model_training

Training pipelines for LoRA fine-tuning and trajectory management.

Classes

D2LTrainConfig

Bases: BaseModel

Pydantic model for D2L training hyperparameters.

Enables validation, JSON serialization (for checkpoint storage), and .model_dump() for MLflow experiment logging.

Attributes:

Name Type Description
base_model_name str

HuggingFace model name for the student/teacher base.

sakana_checkpoint_path str

Path to the Sakana hypernet checkpoint.

num_steps int

Total training steps.

lr float

Learning rate for AdamW optimizer.

alpha float

Blending weight for KL vs CE loss (1.0 = pure KL, 0.0 = pure CE).

temperature float

Softmax temperature for KL divergence computation.

checkpoint_every int

Steps between lightweight checkpoint saves.

full_checkpoint_every int

Steps between full checkpoint saves (incl. optimizer).

checkpoint_dir str

Directory for checkpoint output.

experiment_name str

MLflow experiment name.

dry_run bool

If True, validate tensor shapes then exit.

smoke_test bool

If True, run 5 steps and verify loss trend.

dataset_path str | None

Path to training JSONL file (required for full training).

grad_clip float

Gradient clipping max norm.

warmup_steps int

Number of linear LR warmup steps.

lora_r int

LoRA rank.

max_length int

Maximum tokenizer sequence length.
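The alpha attribute blends the two loss terms. A minimal sketch of that blending as a standalone helper (hypothetical, not the library's implementation):

```python
def blend_losses(kl_loss: float, ce_loss: float, alpha: float) -> float:
    """Blend KL and CE losses: alpha=1.0 is pure KL, alpha=0.0 is pure CE."""
    if not 0.0 <= alpha <= 1.0:
        raise ValueError(f"alpha must be in [0, 1], got {alpha}")
    return alpha * kl_loss + (1.0 - alpha) * ce_loss

# alpha=1.0 keeps only the KL term; alpha=0.0 keeps only the CE term.
print(blend_losses(2.0, 4.0, 1.0))  # 2.0
print(blend_losses(2.0, 4.0, 0.0))  # 4.0
print(blend_losses(2.0, 4.0, 0.5))  # 3.0
```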

Functions

train_d2l_qwen3

train_d2l_qwen3(config: D2LTrainConfig) -> dict[str, Any]

Run KL-divergence context distillation training.

Three execution modes controlled by config flags:

- dry_run=True: Validate shapes with single forward pass, no optimizer step.
- smoke_test=True: Run min(num_steps, 5) steps, assert finite decreasing loss.
- default: Full training from dataset with checkpointing and MLflow tracking.

Parameters:

Name Type Description Default
config D2LTrainConfig

Training configuration.

required

Returns:

Type Description
dict[str, Any]

Dictionary with training results:

- final_loss: Loss at the last step.
- best_loss: Lowest loss seen during training.
- num_steps_completed: Number of training steps completed.
- checkpoint_dir: Path to checkpoint directory.
- shape_summary (dry_run only): Tensor shape validation results.

Source code in libs/model-training/src/model_training/d2l_train.py
def train_d2l_qwen3(config: D2LTrainConfig) -> dict[str, Any]:  # noqa: C901
    """Run KL-divergence context distillation training.

    Three execution modes controlled by config flags:
    - dry_run=True: Validate shapes with single forward pass, no optimizer step.
    - smoke_test=True: Run min(num_steps, 5) steps, assert finite decreasing loss.
    - default: Full training from dataset with checkpointing and MLflow tracking.

    Args:
        config: Training configuration.

    Returns:
        Dictionary with training results:
            - final_loss: Loss at the last step.
            - best_loss: Lowest loss seen during training.
            - num_steps_completed: Number of training steps completed.
            - checkpoint_dir: Path to checkpoint directory.
            - shape_summary (dry_run only): Tensor shape validation results.
    """
    import mlflow  # noqa: PLC0415
    import torch  # noqa: PLC0415
    from ctx_to_lora.modeling.hypernet import HyperLoRA  # noqa: PLC0415
    from torch.nn.utils import clip_grad_norm_  # noqa: PLC0415
    from torch.optim import AdamW  # noqa: PLC0415
    from torch.optim.lr_scheduler import (  # noqa: PLC0415
        CosineAnnealingLR,
        LinearLR,
        SequentialLR,
    )
    from transformers import AutoModelForCausalLM, AutoTokenizer  # noqa: PLC0415

    from model_training.d2l_config import build_qwen3_hypernet_config  # noqa: PLC0415
    from model_training.d2l_data import (  # noqa: PLC0415
        generate_needle_dataset,
        load_jsonl,
        split_by_task_id,
    )
    from model_training.d2l_probe import QWEN3_NEXT_CANONICAL_NAME  # noqa: PLC0415
    from model_training.sakana_d2l import (  # noqa: PLC0415
        get_aggregator_config,
        transfer_aggregator_weights,
    )

    # Mode dispatch: dry_run exits after shape validation
    if config.dry_run:
        shape_summary = _dry_run_validate_shapes(config)
        return {"shape_summary": shape_summary, "status": "dry_run_complete"}

    # Smoke test caps steps at 5
    if config.smoke_test:
        config = config.model_copy(update={"num_steps": min(config.num_steps, 5)})

    # Guard: require probe cache when not in smoke_test mode.
    # smoke_test uses generate_needle_dataset and may not have a real probe cache.
    if not config.smoke_test:
        _require_probe_cache(QWEN3_NEXT_CANONICAL_NAME)

    num_steps = config.num_steps
    warmup_steps = config.warmup_steps

    # Load tokenizer and base model
    logger.info("Loading tokenizer: %s", config.base_model_name)
    tokenizer = AutoTokenizer.from_pretrained(config.base_model_name)

    logger.info("Loading base model: %s", config.base_model_name)
    base_model = AutoModelForCausalLM.from_pretrained(
        config.base_model_name,
        output_hidden_states=True,
    ).eval()

    # Build hypernet config with aggregator config from checkpoint
    logger.info(
        "Building hypernet config from checkpoint: %s",
        config.sakana_checkpoint_path,
    )
    hc = build_qwen3_hypernet_config(
        aggregator_config=get_aggregator_config(config.sakana_checkpoint_path),
        lora_r=config.lora_r,
    )

    # Create hypernet and transfer aggregator weights (freezes aggregator)
    hypernet = HyperLoRA(hc).to(torch.float32)
    hypernet = transfer_aggregator_weights(hypernet, config.sakana_checkpoint_path)
    hypernet.train()

    # Device selection: cuda > mps > cpu
    from shared.hardware import get_best_device  # noqa: PLC0415

    device = torch.device(get_best_device())
    logger.info("Using device: %s", device)
    base_model = base_model.to(device)
    hypernet = hypernet.to(device)

    # Optimizer: only trainable params (head + projections, not frozen aggregator)
    trainable_params = [p for p in hypernet.parameters() if p.requires_grad]
    logger.info("Trainable params: %d", sum(p.numel() for p in trainable_params))
    optimizer = AdamW(trainable_params, lr=config.lr)

    # Scheduler: linear warmup → cosine annealing
    scheduler = SequentialLR(
        optimizer,
        schedulers=[
            LinearLR(optimizer, start_factor=0.01, total_iters=warmup_steps),
            CosineAnnealingLR(
                optimizer,
                T_max=max(1, num_steps - warmup_steps),
                eta_min=1e-6,
            ),
        ],
        milestones=[warmup_steps],
    )

    # Data loading
    if config.smoke_test or config.dataset_path is None:
        records = generate_needle_dataset(n=20)
        logger.info("Using needle dataset (%d records)", len(records))
    else:
        all_records = load_jsonl(config.dataset_path)
        records, _ = split_by_task_id(all_records)
        logger.info(
            "Loaded %d training records from %s",
            len(records),
            config.dataset_path,
        )

    if not records:
        raise ValueError("No training records loaded; cannot train on empty dataset.")

    # MLflow setup and training loop
    _setup_mlflow(config)

    best_loss = float("inf")
    final_loss = float("inf")
    step_losses: list[float] = []

    with mlflow.start_run(run_name=f"{config.experiment_name}-step{num_steps}"):
        mlflow.log_params(config.model_dump())

        for step in range(1, num_steps + 1):
            record = records[(step - 1) % len(records)]

            loss, metrics = _training_step(
                record=record,
                base_model=base_model,
                tokenizer=tokenizer,
                hypernet=hypernet,
                hc=hc,
                config=config,
            )

            loss.backward()
            clip_grad_norm_(trainable_params, config.grad_clip)
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()

            step_loss = metrics["total_loss"]
            step_losses.append(step_loss)
            final_loss = step_loss
            if step_loss < best_loss:
                best_loss = step_loss

            mlflow.log_metrics(metrics, step=step)
            logger.info(
                "Step %d/%d — loss=%.4f (kl=%.4f, ce=%.4f)",
                step,
                num_steps,
                metrics["total_loss"],
                metrics["kl_loss"],
                metrics["ce_loss"],
            )

            # Tiered checkpointing
            if step % config.full_checkpoint_every == 0:
                ckpt_path = _save_checkpoint(
                    step=step,
                    hypernet=hypernet,
                    optimizer=optimizer,
                    scheduler=scheduler,
                    config=config,
                    hc=hc,
                    best_loss=best_loss,
                    full=True,
                )
                mlflow.log_artifact(str(ckpt_path))
            elif step % config.checkpoint_every == 0:
                ckpt_path = _save_checkpoint(
                    step=step,
                    hypernet=hypernet,
                    optimizer=optimizer,
                    scheduler=scheduler,
                    config=config,
                    hc=hc,
                    best_loss=best_loss,
                    full=False,
                )
                mlflow.log_artifact(str(ckpt_path))

    # Smoke test assertions
    if config.smoke_test:
        for i, sl in enumerate(step_losses):
            assert torch.isfinite(torch.tensor(sl)), (  # noqa: S101
                f"Smoke test: loss at step {i + 1} is not finite: {sl}"
            )
        assert step_losses[-1] < step_losses[0], (  # noqa: S101
            f"Smoke test: final loss {step_losses[-1]:.4f} not less than "
            f"initial loss {step_losses[0]:.4f}"
        )
        assert any(p.grad is not None for p in trainable_params), (  # noqa: S101
            "Smoke test: no trainable param has non-None gradient after training"
        )

    return {
        "final_loss": final_loss,
        "best_loss": best_loss,
        "num_steps_completed": num_steps,
        "checkpoint_dir": config.checkpoint_dir,
    }
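The linear-warmup-then-cosine schedule wired up with SequentialLR above can be approximated in plain Python. This is a hypothetical standalone helper for intuition, not the torch schedulers themselves:

```python
import math

def lr_at_step(step: int, base_lr: float, warmup_steps: int,
               num_steps: int, eta_min: float = 1e-6) -> float:
    """Approximate LR: linear warmup to base_lr, then cosine decay to eta_min."""
    if step < warmup_steps:
        # Linear ramp from 1% of base_lr up to base_lr (start_factor=0.01).
        start = 0.01 * base_lr
        return start + (base_lr - start) * step / max(1, warmup_steps)
    # Cosine annealing over the remaining steps.
    t = (step - warmup_steps) / max(1, num_steps - warmup_steps)
    return eta_min + 0.5 * (base_lr - eta_min) * (1 + math.cos(math.pi * t))

# LR rises during warmup, peaks at base_lr, then decays toward eta_min.
lrs = [lr_at_step(s, base_lr=1e-4, warmup_steps=10, num_steps=100) for s in range(100)]
```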

format_for_sft

format_for_sft(
    trajectory: dict[str, Any],
) -> list[dict[str, str]]

Convert a trajectory into SFT-compatible chat format.

Only successful trajectories (outcome == 'success') produce output. Extracts the final step where tests_passed is True as the assistant message.

Parameters:

Name Type Description Default
trajectory dict[str, Any]

A trajectory dict as returned by load_trajectory.

required

Returns:

Type Description
list[dict[str, str]]

A list of 3 message dicts ([system, user, assistant]) for successful trajectories, or an empty list if the trajectory did not succeed.

Source code in libs/model-training/src/model_training/trajectory.py
def format_for_sft(trajectory: dict[str, Any]) -> list[dict[str, str]]:
    """Convert a trajectory into SFT-compatible chat format.

    Only successful trajectories (outcome == 'success') produce output.
    Extracts the final step where tests_passed is True as the assistant message.

    Args:
        trajectory: A trajectory dict as returned by load_trajectory.

    Returns:
        A list of 3 message dicts ([system, user, assistant]) for successful
        trajectories, or an empty list if the trajectory did not succeed.
    """
    if trajectory.get("outcome") != "success":
        return []

    steps: list[dict[str, Any]] = trajectory.get("steps", [])
    successful_step = next(
        (s for s in reversed(steps) if s.get("tests_passed")),
        None,
    )

    if successful_step is None:
        return []

    task_description: str = trajectory.get("task_description", "")
    generated_code: str = successful_step.get("generated_code", "")

    return [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": task_description},
        {"role": "assistant", "content": generated_code},
    ]
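A usage sketch of the function above, inlining a condensed copy so it is self-contained. SYSTEM_PROMPT here is an assumed placeholder; the real constant lives in trajectory.py:

```python
from typing import Any

SYSTEM_PROMPT = "You are a coding assistant."  # placeholder, not the real prompt

def format_for_sft(trajectory: dict[str, Any]) -> list[dict[str, str]]:
    # Condensed version of the function shown above.
    if trajectory.get("outcome") != "success":
        return []
    steps = trajectory.get("steps", [])
    step = next((s for s in reversed(steps) if s.get("tests_passed")), None)
    if step is None:
        return []
    return [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": trajectory.get("task_description", "")},
        {"role": "assistant", "content": step.get("generated_code", "")},
    ]

trajectory = {
    "outcome": "success",
    "task_description": "Write an add() function.",
    "steps": [
        {"tests_passed": False, "generated_code": "def add(a): ..."},
        {"tests_passed": True, "generated_code": "def add(a, b): return a + b"},
    ],
}
messages = format_for_sft(trajectory)
# Three messages: system, user, and the last passing step's code as assistant.
```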

load_trajectory

load_trajectory(trajectory_id: str) -> dict[str, Any]

Load a stored trajectory by session ID.

Parameters:

Name Type Description Default
trajectory_id str

The session ID used as the filename (without .json).

required

Returns:

Type Description
dict[str, Any]

A dict containing the full trajectory data including steps and metadata.

Raises:

Type Description
FileNotFoundError

If no trajectory file exists for the given ID.

Source code in libs/model-training/src/model_training/trajectory.py
def load_trajectory(trajectory_id: str) -> dict[str, Any]:
    """Load a stored trajectory by session ID.

    Args:
        trajectory_id: The session ID used as the filename (without .json).

    Returns:
        A dict containing the full trajectory data including steps and metadata.

    Raises:
        FileNotFoundError: If no trajectory file exists for the given ID.
    """
    trajectory_dir = _get_trajectory_dir()
    file_path = trajectory_dir / f"{trajectory_id}.json"
    # Let FileNotFoundError propagate naturally if file does not exist
    return json.loads(file_path.read_text())  # type: ignore[no-any-return]

record_trajectory

record_trajectory(
    session_id: str,
    steps: list[dict[str, Any]],
    outcome: Optional[str] = None,
    *,
    task_description: str = "",
    task_type: str = "",
    adapter_ids: list[str] | None = None,
) -> dict[str, Any]

Persist a coding session trajectory to disk for future distillation.

Parameters:

Name Type Description Default
session_id str

Unique identifier for the coding session.

required
steps list[dict[str, Any]]

List of step dicts, each containing attempt results.

required
outcome Optional[str]

Final session result ('success', 'exhausted', or None).

None
task_description str

Natural language description of the coding task.

''
task_type str

Category of task (e.g. 'function', 'class', 'refactor').

''
adapter_ids list[str] | None

LoRA adapter IDs used during the session.

None

Returns:

Type Description
dict[str, Any]

A dict with 'session_id' and 'file_path' keys.

Source code in libs/model-training/src/model_training/trajectory.py
def record_trajectory(
    session_id: str,
    steps: list[dict[str, Any]],
    outcome: Optional[str] = None,
    *,
    task_description: str = "",
    task_type: str = "",
    adapter_ids: list[str] | None = None,
) -> dict[str, Any]:
    """Persist a coding session trajectory to disk for future distillation.

    Args:
        session_id: Unique identifier for the coding session.
        steps: List of step dicts, each containing attempt results.
        outcome: Final session result ('success', 'exhausted', or None).
        task_description: Natural language description of the coding task.
        task_type: Category of task (e.g. 'function', 'class', 'refactor').
        adapter_ids: LoRA adapter IDs used during the session.

    Returns:
        A dict with 'session_id' and 'file_path' keys.
    """
    trajectory_dir = _get_trajectory_dir()
    trajectory_dir.mkdir(parents=True, exist_ok=True)

    file_path = trajectory_dir / f"{session_id}.json"
    timestamp = datetime.now(tz=timezone.utc).isoformat()

    trajectory: dict[str, Any] = {
        "session_id": session_id,
        "task_description": task_description,
        "task_type": task_type,
        "adapter_ids": adapter_ids if adapter_ids is not None else [],
        "outcome": outcome,
        "timestamp": timestamp,
        "steps": steps,
    }

    file_path.write_text(json.dumps(trajectory, indent=2))

    return {"session_id": session_id, "file_path": str(file_path)}

Modules

config

Training configuration for LoRA fine-tuning.

No GPU imports required — this module is pure dict construction and validation.

Functions
get_training_config
get_training_config(
    task_type: str,
    rank: int = 64,
    epochs: int = 3,
    learning_rate: float = 0.0002,
) -> dict[str, Any]

Return a training configuration dict with hyperparameters.

Generates a configuration appropriate for the given task type, with sensible defaults for QLoRA fine-tuning.

Parameters:

Name Type Description Default
task_type str

Task category (e.g. 'bug-fix', 'feature-impl').

required
rank int

LoRA rank for the adapter.

64
epochs int

Number of training epochs.

3
learning_rate float

Learning rate for the optimizer.

0.0002

Returns:

Type Description
dict[str, Any]

A dict containing all training hyperparameters.

Example

>>> config = get_training_config("bug-fix", rank=64, epochs=3)
>>> config["task_type"]
'bug-fix'

Source code in libs/model-training/src/model_training/config.py
def get_training_config(
    task_type: str,
    rank: int = 64,
    epochs: int = 3,
    learning_rate: float = 2e-4,
) -> dict[str, Any]:
    """Return a training configuration dict with hyperparameters.

    Generates a configuration appropriate for the given task type,
    with sensible defaults for QLoRA fine-tuning.

    Args:
        task_type: Task category (e.g. 'bug-fix', 'feature-impl').
        rank: LoRA rank for the adapter.
        epochs: Number of training epochs.
        learning_rate: Learning rate for the optimizer.

    Returns:
        A dict containing all training hyperparameters.

    Example:
        >>> config = get_training_config("bug-fix", rank=64, epochs=3)
        >>> config["task_type"]
        'bug-fix'
    """
    return {
        "task_type": task_type,
        "rank": rank,
        "alpha": 2 * rank,
        "epochs": epochs,
        "learning_rate": learning_rate,
        "warmup_ratio": 0.03,
        "lr_scheduler_type": "cosine",
        "bf16": True,
        "per_device_train_batch_size": 1,
        "gradient_accumulation_steps": 4,
        "save_strategy": "no",
        "logging_steps": 1,
        "report_to": "none",
        "eval_strategy": "no",
        "target_modules": ["q_proj", "v_proj"],
        "dropout": 0.1,
    }
validate_config
validate_config(config: dict[str, Any]) -> bool

Validate training configuration fields and value ranges.

Checks that all required fields are present and their values fall within acceptable ranges.

Parameters:

Name Type Description Default
config dict[str, Any]

A training configuration dict to validate.

required

Returns:

Type Description
bool

True if the configuration is valid.

Raises:

Type Description
ValueError

If required keys are missing or values are out of range.

Example

>>> valid = validate_config({"task_type": "bug-fix", "rank": 64,
...                         "epochs": 3, "learning_rate": 2e-4})
>>> valid
True

Source code in libs/model-training/src/model_training/config.py
def validate_config(config: dict[str, Any]) -> bool:
    """Validate training configuration fields and value ranges.

    Checks that all required fields are present and their values
    fall within acceptable ranges.

    Args:
        config: A training configuration dict to validate.

    Returns:
        True if the configuration is valid.

    Raises:
        ValueError: If required keys are missing or values are out of range.

    Example:
        >>> valid = validate_config({"task_type": "bug-fix", "rank": 64,
        ...                         "epochs": 3, "learning_rate": 2e-4})
        >>> valid
        True
    """
    missing = _REQUIRED_KEYS - set(config.keys())
    if missing:
        raise ValueError(f"Missing required config keys: {sorted(missing)}")

    rank = config["rank"]
    if not (isinstance(rank, int) and rank > 0 and rank <= 256):
        raise ValueError(f"rank must be an integer in range (0, 256], got {rank!r}")

    epochs = config["epochs"]
    if not (isinstance(epochs, int) and epochs > 0 and epochs <= 100):
        raise ValueError(f"epochs must be an integer in range (0, 100], got {epochs!r}")

    lr = config["learning_rate"]
    if not (isinstance(lr, float) and lr > 0.0 and lr < 1.0):
        raise ValueError(f"learning_rate must be a float in range (0, 1), got {lr!r}")

    return True
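A usage sketch of the validation rules, inlining a condensed copy of the checks. The `_REQUIRED_KEYS` set is an assumption here; the real constant is defined elsewhere in config.py:

```python
from typing import Any

_REQUIRED_KEYS = {"task_type", "rank", "epochs", "learning_rate"}  # assumed

def validate_config(config: dict[str, Any]) -> bool:
    # Condensed version of the checks shown above.
    missing = _REQUIRED_KEYS - set(config.keys())
    if missing:
        raise ValueError(f"Missing required config keys: {sorted(missing)}")
    if not (isinstance(config["rank"], int) and 0 < config["rank"] <= 256):
        raise ValueError(f"rank out of range: {config['rank']!r}")
    if not (isinstance(config["epochs"], int) and 0 < config["epochs"] <= 100):
        raise ValueError(f"epochs out of range: {config['epochs']!r}")
    lr = config["learning_rate"]
    if not (isinstance(lr, float) and 0.0 < lr < 1.0):
        raise ValueError(f"learning_rate out of range: {lr!r}")
    return True

ok = validate_config(
    {"task_type": "bug-fix", "rank": 64, "epochs": 3, "learning_rate": 2e-4}
)

try:
    validate_config(
        {"task_type": "bug-fix", "rank": 512, "epochs": 3, "learning_rate": 2e-4}
    )
except ValueError as e:
    err = str(e)  # rank 512 exceeds the (0, 256] range
```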

d2l_config

Config helpers for Qwen3-Coder-Next hypernetwork training.

Requires transformers>=5.0 for Qwen3NextConfig (hybrid linear/full attention architecture). transformers 5.3.0 is installed in this project.

All heavy imports (transformers, ctx_to_lora, peft) are deferred to function bodies per project convention (INFRA-05) to avoid GPU imports at module level.

Functions
get_d2l_qwen3_config
get_d2l_qwen3_config() -> dict[str, Any]

Return Qwen3-Coder-Next architecture dimensions without loading model weights.

Uses Qwen3NextConfig defaults which exactly match Qwen3-Coder-Next specs:

- hidden_size: 2048
- num_hidden_layers: 48 (12 full_attention + 36 linear_attention)
- num_attention_heads: 16 (Q heads), num_key_value_heads: 2 (GQA KV)
- head_dim: 256
- full_attention layer indices: [3, 7, 11, 15, 19, 23, 27, 31, 35, 39, 43, 47]
- vocab_size: 151936
- model_type: "qwen3_next"

Returns:

Type Description
dict[str, Any]

Dict with keys: hidden_size, num_hidden_layers, num_attention_heads, num_key_value_heads, head_dim, attention_layer_indices, vocab_size, model_type.

Source code in libs/model-training/src/model_training/d2l_config.py
def get_d2l_qwen3_config() -> dict[str, Any]:
    """Return Qwen3-Coder-Next architecture dimensions without loading model weights.

    Uses Qwen3NextConfig defaults which exactly match Qwen3-Coder-Next specs:
    - hidden_size: 2048
    - num_hidden_layers: 48 (12 full_attention + 36 linear_attention)
    - num_attention_heads: 16 (Q heads), num_key_value_heads: 2 (GQA KV)
    - head_dim: 256
    - full_attention layer indices: [3, 7, 11, 15, 19, 23, 27, 31, 35, 39, 43, 47]
    - vocab_size: 151936
    - model_type: "qwen3_next"

    Returns:
        Dict with keys: hidden_size, num_hidden_layers, num_attention_heads,
        num_key_value_heads, head_dim, attention_layer_indices, vocab_size,
        model_type.
    """
    from transformers import Qwen3NextConfig  # noqa: PLC0415

    cfg = Qwen3NextConfig()
    layer_types: list[str] = cfg.layer_types or []
    attention_layer_indices = [
        i for i, t in enumerate(layer_types) if t == "full_attention"
    ]
    return {
        "hidden_size": cfg.hidden_size,
        "num_hidden_layers": cfg.num_hidden_layers,
        "num_attention_heads": cfg.num_attention_heads,
        "num_key_value_heads": cfg.num_key_value_heads,
        "head_dim": cfg.head_dim,
        "attention_layer_indices": attention_layer_indices,
        "vocab_size": cfg.vocab_size,
        "model_type": cfg.model_type,
    }
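The index-discovery pattern used above works on any layer_types list. Here it is with a hypothetical 8-layer pattern rather than the real 48-layer config:

```python
# Hypothetical 8-layer pattern: every 4th layer is full attention.
layer_types = [
    "linear_attention", "linear_attention", "linear_attention", "full_attention",
    "linear_attention", "linear_attention", "linear_attention", "full_attention",
]
attention_layer_indices = [
    i for i, t in enumerate(layer_types) if t == "full_attention"
]
print(attention_layer_indices)  # [3, 7]
```

Against the real Qwen3NextConfig, the same comprehension yields the 12 indices listed above.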
build_qwen3_hypernet_config
build_qwen3_hypernet_config(
    lora_r: int = 8,
    target_modules: list[str] | None = None,
    aggregator_config: Any = None,
) -> Any

Construct HypernetConfig targeting Qwen3-Coder-Next attention layers.

Discovers full_attention layer indices dynamically from Qwen3NextConfig.layer_types. Result has exactly 12 layer indices matching the Qwen3-Coder-Next architecture.

Phase 26 probe cache integration: if a probe cache exists for QWEN3_NEXT_CANONICAL_NAME, uses real per-projection in/out dimensions for feature_sizes. Falls back to hidden_size placeholder when no cache is found (e.g., in CI where the model has not been probed).

Parameters:

Name Type Description Default
lora_r int

LoRA rank for the adapter. Defaults to 8.

8
target_modules list[str] | None

LoRA target module names. Defaults to ["q_proj", "v_proj"].

None
aggregator_config Any

Perceiver aggregator config from a Sakana checkpoint. If None (default / Phase 25 CI), HypernetConfig is built with aggregator_config=None as placeholder. Phase 29 populates this via get_aggregator_config() with a loaded model.

None

Returns:

Type Description
Any

HypernetConfig with layer_indices set to the 12 full_attention indices and base_hidden_size=2048.

Source code in libs/model-training/src/model_training/d2l_config.py
def build_qwen3_hypernet_config(
    lora_r: int = 8,
    target_modules: list[str] | None = None,
    aggregator_config: Any = None,
) -> Any:
    """Construct HypernetConfig targeting Qwen3-Coder-Next attention layers.

    Discovers full_attention layer indices dynamically from Qwen3NextConfig.layer_types.
    Result has exactly 12 layer indices matching the Qwen3-Coder-Next architecture.

    Phase 26 probe cache integration: if a probe cache exists for
    QWEN3_NEXT_CANONICAL_NAME, uses real per-projection in/out dimensions for
    feature_sizes. Falls back to hidden_size placeholder when no cache is found
    (e.g., in CI where the model has not been probed).

    Args:
        lora_r: LoRA rank for the adapter. Defaults to 8.
        target_modules: LoRA target module names. Defaults to ["q_proj", "v_proj"].
        aggregator_config: Perceiver aggregator config from a Sakana checkpoint.
            If None (default / Phase 25 CI), HypernetConfig is built with
            aggregator_config=None as placeholder. Phase 29 populates this via
            get_aggregator_config() with a loaded model.

    Returns:
        HypernetConfig with layer_indices set to the 12 full_attention indices
        and base_hidden_size=2048.
    """
    from ctx_to_lora.modeling.hypernet import HypernetConfig  # noqa: PLC0415
    from peft import LoraConfig  # noqa: PLC0415
    from transformers import Qwen3NextConfig  # noqa: PLC0415

    if target_modules is None:
        target_modules = ["q_proj", "v_proj"]

    cfg = Qwen3NextConfig()
    layer_types: list[str] = cfg.layer_types or []
    layer_indices = [
        i for i, t in enumerate(layer_types) if t == "full_attention"
    ]  # Always [3, 7, 11, 15, 19, 23, 27, 31, 35, 39, 43, 47]

    lora_config = LoraConfig(
        r=lora_r,
        lora_alpha=lora_r * 2,
        target_modules=target_modules,
        lora_dropout=0.0,
        bias="none",
        task_type="CAUSAL_LM",
    )

    from model_training.d2l_probe import (  # noqa: PLC0415
        QWEN3_NEXT_CANONICAL_NAME,
        load_probe_cache,
    )

    cache = load_probe_cache(QWEN3_NEXT_CANONICAL_NAME)
    if cache is not None:
        in_sizes = {mod: cache["feature_sizes"][mod]["in"] for mod in target_modules}
        out_sizes = {mod: cache["feature_sizes"][mod]["out"] for mod in target_modules}
        feature_sizes: tuple[dict[str, int], dict[str, int]] = (in_sizes, out_sizes)
        logger.info("Using probe cache feature_sizes for %s", QWEN3_NEXT_CANONICAL_NAME)
    else:
        hidden: int = cfg.hidden_size or 2048
        _placeholder: dict[str, int] = dict.fromkeys(target_modules, hidden)
        feature_sizes = (_placeholder, dict.fromkeys(target_modules, hidden))
        logger.warning(
            "No probe cache for '%s' — using hidden_size=%d as placeholder. "
            "Run probe_model() and save_probe_cache() to set real dimensions.",
            QWEN3_NEXT_CANONICAL_NAME,
            cfg.hidden_size,
        )

    return HypernetConfig(
        latent_size=512,
        use_light_weight_lora=False,
        light_weight_latent_size=128,
        per_rank_gen=False,
        use_per_rank_bias=False,
        use_bias=True,
        per_layer_processing=False,
        use_token_mixing=False,
        num_pre_head_layers=1,
        dropout_rate=0.0,
        lora_config=lora_config,
        extra_modules=None,
        base_hidden_size=cfg.hidden_size,
        layer_indices=layer_indices,
        feature_sizes=feature_sizes,
        aggregator_config=aggregator_config,
    )

d2l_data

Data pipeline for KL-divergence context distillation training.

Provides functions for:

- Converting trajectories to distillation records (activation/teacher split)
- Generating needle-in-haystack synthetic datasets for CI smoke testing
- JSONL persistence (save/load round-trip)
- Task-ID-level train/test splitting

Functions
save_jsonl
save_jsonl(
    records: list[dict[str, Any]], path: str | Path
) -> None

Persist a list of dicts to a JSONL file.

Parameters:

Name Type Description Default
records list[dict[str, Any]]

List of JSON-serializable dicts to save.

required
path str | Path

File path for output. Parent directories are created if needed.

required
Source code in libs/model-training/src/model_training/d2l_data.py
def save_jsonl(records: list[dict[str, Any]], path: str | Path) -> None:
    """Persist a list of dicts to a JSONL file.

    Args:
        records: List of JSON-serializable dicts to save.
        path: File path for output. Parent directories are created if needed.
    """
    out = Path(path)
    out.parent.mkdir(parents=True, exist_ok=True)
    with out.open("w", encoding="utf-8") as f:
        for record in records:
            f.write(json.dumps(record, ensure_ascii=False) + "\n")
load_jsonl
load_jsonl(path: str | Path) -> list[dict[str, Any]]

Load records from a JSONL file.

Parameters:

Name Type Description Default
path str | Path

Path to the JSONL file.

required

Returns:

Type Description
list[dict[str, Any]]

List of dicts, one per non-empty line.

Source code in libs/model-training/src/model_training/d2l_data.py
def load_jsonl(path: str | Path) -> list[dict[str, Any]]:
    """Load records from a JSONL file.

    Args:
        path: Path to the JSONL file.

    Returns:
        List of dicts, one per non-empty line.
    """
    src = Path(path)
    return [
        json.loads(line)
        for line in src.read_text(encoding="utf-8").splitlines()
        if line.strip()
    ]
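The save/load pair round-trips records losslessly, including non-ASCII text (ensure_ascii=False). A self-contained sketch using a temporary directory:

```python
import json
import tempfile
from pathlib import Path

def save_jsonl(records, path):
    # One JSON object per line; parent dirs created as needed.
    out = Path(path)
    out.parent.mkdir(parents=True, exist_ok=True)
    with out.open("w", encoding="utf-8") as f:
        for record in records:
            f.write(json.dumps(record, ensure_ascii=False) + "\n")

def load_jsonl(path):
    # Skip blank lines; parse each remaining line as JSON.
    return [
        json.loads(line)
        for line in Path(path).read_text(encoding="utf-8").splitlines()
        if line.strip()
    ]

records = [{"task_id": "t1", "text": "héllo"}, {"task_id": "t2", "text": "world"}]
path = Path(tempfile.mkdtemp()) / "data.jsonl"
save_jsonl(records, path)
assert load_jsonl(path) == records  # round trip preserves content
```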
format_for_distillation
format_for_distillation(
    trajectory: dict[str, Any],
) -> list[dict[str, str]]

Convert a trajectory to distillation records with activation/teacher split.

Each record has:

- activation_text: trajectory context + task description (NO answer tokens)
- teacher_text: trajectory context + task description + answer
- task_id: identifier for train/test splitting

Only successful trajectories (outcome == 'success') produce records.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `trajectory` | `dict[str, Any]` | Trajectory dict with task_id/session_id, task_description, steps, and outcome fields. | required |

Returns:

| Type | Description |
| --- | --- |
| `list[dict[str, str]]` | List of distillation record dicts, or empty list if no successful outcome. |

Source code in libs/model-training/src/model_training/d2l_data.py
def format_for_distillation(trajectory: dict[str, Any]) -> list[dict[str, str]]:
    """Convert a trajectory to distillation records with activation/teacher split.

    Each record has:
    - activation_text: trajectory context + task description (NO answer tokens)
    - teacher_text: trajectory context + task description + answer
    - task_id: identifier for train/test splitting

    Only successful trajectories (outcome == 'success') produce records.

    Args:
        trajectory: Trajectory dict with task_id/session_id, task_description,
            steps, and outcome fields.

    Returns:
        List of distillation record dicts, or empty list if no successful outcome.
    """
    if trajectory.get("outcome") != "success":
        return []

    steps: list[dict[str, Any]] = trajectory.get("steps", [])
    task_description: str = trajectory.get("task_description", "")
    task_id: str = trajectory.get("task_id") or trajectory.get("session_id", "")

    # Build trajectory context from step descriptions (no answer tokens)
    trajectory_parts: list[str] = []
    for step in steps:
        desc = step.get("description", "")
        if desc:
            trajectory_parts.append(desc)

    trajectory_text = "\n".join(trajectory_parts)

    # Build activation_text (context only, no answer)
    activation_base = f"{trajectory_text}\n{task_description}".strip()

    records: list[dict[str, str]] = []
    for step in steps:
        if not step.get("tests_passed"):
            continue

        # Extract answer from canonical_solution or generated_code
        answer: str = step.get("canonical_solution") or step.get("generated_code", "")
        if not answer:
            continue

        records.append(
            {
                "task_id": task_id,
                "activation_text": activation_base,
                "teacher_text": f"{activation_base}\n{answer}",
            }
        )

    return records
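To make the activation/teacher split concrete, here is a self-contained sketch that mirrors the documented behavior on a toy trajectory (the function body is copied inline so the snippet runs without importing `model_training`):

```python
from typing import Any


def format_for_distillation(trajectory: dict[str, Any]) -> list[dict[str, str]]:
    # Only successful trajectories produce records.
    if trajectory.get("outcome") != "success":
        return []
    steps = trajectory.get("steps", [])
    task_description = trajectory.get("task_description", "")
    task_id = trajectory.get("task_id") or trajectory.get("session_id", "")
    # Trajectory context is built from step descriptions only -- no answer tokens.
    trajectory_text = "\n".join(
        s.get("description", "") for s in steps if s.get("description")
    )
    activation_base = f"{trajectory_text}\n{task_description}".strip()
    records: list[dict[str, str]] = []
    for step in steps:
        if not step.get("tests_passed"):
            continue
        answer = step.get("canonical_solution") or step.get("generated_code", "")
        if not answer:
            continue
        records.append(
            {
                "task_id": task_id,
                "activation_text": activation_base,
                "teacher_text": f"{activation_base}\n{answer}",
            }
        )
    return records


trajectory = {
    "task_id": "demo_1",
    "task_description": "Implement add(a, b).",
    "outcome": "success",
    "steps": [
        {"description": "Wrote a first draft", "tests_passed": False},
        {
            "description": "Fixed the return value",
            "tests_passed": True,
            "generated_code": "def add(a, b):\n    return a + b",
        },
    ],
}
records = format_for_distillation(trajectory)
failed = format_for_distillation({**trajectory, "outcome": "failure"})
```

Note the invariant: teacher_text is always a strict prefix extension of activation_text, and the answer code never appears in activation_text.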
normalize_mined_trajectory
normalize_mined_trajectory(
    mined: dict[str, Any],
) -> dict[str, Any]

Convert a GitHub-mined trajectory dict into distillation-ready format.

Maps the mining pipeline's output (PR/issue metadata with commit and review steps) into the schema expected by format_for_distillation.

Outcome mapping:

- pr_* task_ids: "merged" -> "success", else "failure"
- issue_* task_ids: "closed" -> "success", else "failure"
- Fallback: "merged" or "closed" -> "success", else "failure"

Step mapping:

- Commit steps get a [Commit] prefix in the description and their content as generated_code.
- Review steps get a [Review] prefix with content inlined into the description and empty generated_code.
- Only the last commit step receives tests_passed=True and canonical_solution (and only when the overall outcome is success).

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `mined` | `dict[str, Any]` | Dict from the GitHub mining pipeline with keys task_id, task_description, outcome, and steps (list of dicts with type, description, and content fields). | required |

Returns:

| Type | Description |
| --- | --- |
| `dict[str, Any]` | Trajectory dict ready for format_for_distillation. |

Source code in libs/model-training/src/model_training/d2l_data.py
def normalize_mined_trajectory(mined: dict[str, Any]) -> dict[str, Any]:
    """Convert a GitHub-mined trajectory dict into distillation-ready format.

    Maps the mining pipeline's output (PR/issue metadata with commit and review
    steps) into the schema expected by ``format_for_distillation``.

    Outcome mapping:
        - ``pr_*`` task_ids: ``"merged"`` -> ``"success"``, else ``"failure"``
        - ``issue_*`` task_ids: ``"closed"`` -> ``"success"``, else ``"failure"``
        - Fallback: ``"merged"`` or ``"closed"`` -> ``"success"``, else ``"failure"``

    Step mapping:
        - Commit steps get ``[Commit]`` prefix in description and their content
          as ``generated_code``.
        - Review steps get ``[Review]`` prefix with content inlined into
          description and empty ``generated_code``.
        - Only the **last** commit step receives ``tests_passed=True`` and
          ``canonical_solution`` (and only when the overall outcome is success).

    Args:
        mined: Dict from the GitHub mining pipeline with keys ``task_id``,
            ``task_description``, ``outcome``, and ``steps`` (list of dicts
            with ``type``, ``description``, and ``content`` fields).

    Returns:
        Trajectory dict ready for ``format_for_distillation``.
    """
    task_id: str = mined.get("task_id", "")
    raw_outcome: str = mined.get("outcome", "")
    raw_steps: list[dict[str, Any]] = mined.get("steps", [])

    # --- Determine normalized outcome ---
    if task_id.startswith("pr_"):
        normalized_outcome = "success" if raw_outcome == "merged" else "failure"
    elif task_id.startswith("issue_"):
        normalized_outcome = "success" if raw_outcome == "closed" else "failure"
    else:
        normalized_outcome = (
            "success" if raw_outcome in ("merged", "closed") else "failure"
        )

    # --- Identify commit steps to find the last one ---
    commit_steps = [s for s in raw_steps if s.get("type") == "commit"]

    # --- Normalize steps ---
    normalized_steps: list[dict[str, Any]] = []
    for step in raw_steps:
        step_type = step.get("type", "")
        step_description = step.get("description", "")
        step_content = step.get("content", "")

        if step_type == "commit":
            is_last_commit = commit_steps and step is commit_steps[-1]
            tests_passed = is_last_commit and normalized_outcome == "success"
            entry: dict[str, Any] = {
                "description": f"[Commit] {step_description}",
                "generated_code": step_content,
                "tests_passed": tests_passed,
            }
            if is_last_commit and normalized_outcome == "success":
                entry["canonical_solution"] = step_content
            normalized_steps.append(entry)
        elif step_type == "review":
            normalized_steps.append(
                {
                    "description": f"[Review] {step_content}",
                    "generated_code": "",
                    "tests_passed": False,
                }
            )

    return {
        "task_id": task_id,
        "session_id": task_id,
        "task_description": mined.get("task_description", ""),
        "steps": normalized_steps,
        "outcome": normalized_outcome,
    }
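The outcome mapping is the subtle part: "closed" counts as success for issues but not for PRs (a closed-but-unmerged PR is a failure). A standalone sketch of just that rule; the helper name `normalize_outcome` is illustrative, not part of the API:

```python
def normalize_outcome(task_id: str, raw_outcome: str) -> str:
    """Map a raw mined outcome to success/failure, as documented above."""
    if task_id.startswith("pr_"):
        # PRs succeed only when merged; "closed" means rejected.
        return "success" if raw_outcome == "merged" else "failure"
    if task_id.startswith("issue_"):
        # Issues succeed when closed.
        return "success" if raw_outcome == "closed" else "failure"
    # Fallback for task_ids with neither prefix.
    return "success" if raw_outcome in ("merged", "closed") else "failure"
```

Usage: `normalize_outcome("pr_42", "closed")` returns `"failure"`, while `normalize_outcome("issue_7", "closed")` returns `"success"`.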
generate_needle_dataset
generate_needle_dataset(
    n: int = 20,
) -> list[dict[str, str]]

Generate needle-in-haystack records for CI smoke testing.

Records are deterministic (no randomness, no LLM). Each record contains a code fact embedded in a function/class context with a query and answer.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `n` | `int` | Number of records to generate. Cycles through templates if n exceeds the number of available templates. | 20 |

Returns:

| Type | Description |
| --- | --- |
| `list[dict[str, str]]` | List of n record dicts with activation_text, teacher_text, and task_id. |

Source code in libs/model-training/src/model_training/d2l_data.py
def generate_needle_dataset(n: int = 20) -> list[dict[str, str]]:
    """Generate needle-in-haystack records for CI smoke testing.

    Records are deterministic (no randomness, no LLM). Each record contains
    a code fact embedded in a function/class context with a query and answer.

    Args:
        n: Number of records to generate. Cycles through templates if n exceeds
           the number of available templates.

    Returns:
        List of n record dicts with activation_text, teacher_text, and task_id.
    """
    records: list[dict[str, str]] = []
    n_templates = len(_NEEDLE_TEMPLATES)
    n_slots = len(_SLOT_VALUES)

    for i in range(n):
        template = _NEEDLE_TEMPLATES[i % n_templates]
        slots = _SLOT_VALUES[i % n_slots]

        try:
            trajectory = template["trajectory_template"].format(**slots)
            query = template["query_template"].format(**slots)
            answer = template["answer_template"].format(**slots)
        except KeyError:
            # If template has a slot not in _SLOT_VALUES, use a safe fallback
            trajectory = (
                f"# code fact {i}\ndef func_{i}(x: int) -> int:\n    return {i}"
            )
            query = f"What does func_{i} return?"
            answer = str(i)

        activation_text = f"{trajectory}\n\nQ: {query}"
        teacher_text = f"{trajectory}\n\nQ: {query}\nA: {answer}"

        records.append(
            {
                "task_id": f"needle_{i}",
                "activation_text": activation_text,
                "teacher_text": teacher_text,
            }
        )

    return records
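The template-cycling behavior can be sketched with a toy template set. The `_NEEDLE_TEMPLATES` and `_SLOT_VALUES` constants are private to `d2l_data.py`, so the template contents below are illustrative assumptions; only the cycling and record shape mirror the documented behavior:

```python
# Illustrative stand-ins for the private _NEEDLE_TEMPLATES / _SLOT_VALUES.
TEMPLATES = [
    {
        "trajectory_template": "def {fn}() -> int:\n    return {value}",
        "query_template": "What does {fn} return?",
        "answer_template": "{value}",
    },
]
SLOTS = [{"fn": "get_port", "value": "8080"}, {"fn": "get_retries", "value": "3"}]


def generate_needle_dataset(n: int = 20) -> list[dict[str, str]]:
    records: list[dict[str, str]] = []
    for i in range(n):
        # Indices wrap independently, so templates and slots cycle.
        template = TEMPLATES[i % len(TEMPLATES)]
        slots = SLOTS[i % len(SLOTS)]
        trajectory = template["trajectory_template"].format(**slots)
        query = template["query_template"].format(**slots)
        answer = template["answer_template"].format(**slots)
        records.append(
            {
                "task_id": f"needle_{i}",
                "activation_text": f"{trajectory}\n\nQ: {query}",
                "teacher_text": f"{trajectory}\n\nQ: {query}\nA: {answer}",
            }
        )
    return records


needles = generate_needle_dataset(4)
```

With one template and two slot sets, records 0 and 2 share identical text but distinct task_ids, which is why the function is deterministic and safe for CI.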
generate_trajectory_dataset
generate_trajectory_dataset(
    source: str = "humaneval", max_tasks: int | None = None
) -> list[dict[str, str]]

Generate trajectory records from a coding task dataset.

Each record has activation_text (prompt only, no solution) and teacher_text (prompt + canonical solution), making it ready for KL-divergence distillation.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `source` | `str` | Dataset source identifier. Currently only "humaneval" is supported. | 'humaneval' |
| `max_tasks` | `int \| None` | Maximum number of tasks to process. If None, all tasks are used. | None |

Returns:

| Type | Description |
| --- | --- |
| `list[dict[str, str]]` | List of trajectory record dicts with task_id, activation_text, and teacher_text fields. |

Raises:

| Type | Description |
| --- | --- |
| `ValueError` | If an unsupported source is specified. |

Source code in libs/model-training/src/model_training/d2l_data.py
def generate_trajectory_dataset(
    source: str = "humaneval",
    max_tasks: int | None = None,
) -> list[dict[str, str]]:
    """Generate trajectory records from a coding task dataset.

    Each record has activation_text (prompt only, no solution) and teacher_text
    (prompt + canonical solution), making it ready for KL-divergence distillation.

    Args:
        source: Dataset source identifier. Currently only "humaneval" is supported.
        max_tasks: Maximum number of tasks to process. If None, all tasks are used.

    Returns:
        List of trajectory record dicts with task_id, activation_text, and
        teacher_text fields.

    Raises:
        ValueError: If an unsupported source is specified.
    """
    from datasets import load_dataset  # noqa: PLC0415

    if source == "humaneval":
        dataset = load_dataset("openai_humaneval", split="test", trust_remote_code=True)
    else:
        raise ValueError(
            f"Unsupported source: {source!r}. Only 'humaneval' is supported."
        )

    records: list[dict[str, str]] = []
    tasks = list(dataset)[:max_tasks]  # [:None] keeps all tasks
    for task in tasks:
        task_id: str = task["task_id"]
        prompt: str = task["prompt"]
        canonical_solution: str = task["canonical_solution"]

        # activation_text: the function signature + docstring (no solution)
        activation_text = prompt.rstrip()
        # teacher_text: prompt + solution (the full implementation)
        teacher_text = prompt.rstrip() + canonical_solution

        records.append(
            {
                "task_id": task_id,
                "activation_text": activation_text,
                "teacher_text": teacher_text,
            }
        )

    logger.info("Generated %d trajectory records from %r", len(records), source)
    return records
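The record shaping itself does not depend on the `datasets` download; on a fabricated HumanEval-style task dict (the real entries come from `load_dataset("openai_humaneval", split="test")`, and this task's content is invented for illustration) it reduces to:

```python
# Fabricated task in the HumanEval schema: task_id, prompt, canonical_solution.
task = {
    "task_id": "HumanEval/demo",
    "prompt": (
        "def add(a: int, b: int) -> int:\n"
        '    """Return the sum of a and b."""\n'
    ),
    "canonical_solution": "    return a + b\n",
}

record = {
    "task_id": task["task_id"],
    # activation_text: signature + docstring only, trailing whitespace stripped
    "activation_text": task["prompt"].rstrip(),
    # teacher_text: prompt + canonical solution, i.e. the full implementation
    "teacher_text": task["prompt"].rstrip() + task["canonical_solution"],
}
```

As with the other generators, teacher_text extends activation_text, so the KL loss can be masked to the solution tokens alone.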
augment_trajectories
augment_trajectories(
    trajectories: list[dict[str, Any]],
    n_variants: int = 3,
    model: str = "qwen2.5-coder:1.5b",
    ollama_base_url: str | None = None,
) -> list[dict[str, Any]]

Produce LLM-augmented variants of trajectory records.

For each input trajectory, generates up to n_variants augmented records using an Ollama LLM. Augmented records always inherit the source task_id to preserve split integrity when mixed with originals.

Augmentation strategies (up to n_variants selected in order):

1. Paraphrase the task description
2. Reorder/drop steps in the trajectory
3. Rename variables throughout the trajectory

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `trajectories` | `list[dict[str, Any]]` | List of trajectory record dicts with task_id, activation_text, and teacher_text fields. | required |
| `n_variants` | `int` | Number of augmented variants to produce per trajectory. Maximum 3 (one per augmentation strategy). | 3 |
| `model` | `str` | Ollama model identifier to use for augmentation. | 'qwen2.5-coder:1.5b' |
| `ollama_base_url` | `str \| None` | Ollama base URL. Defaults to "http://localhost:11434". | None |

Returns:

| Type | Description |
| --- | --- |
| `list[dict[str, Any]]` | List of augmented trajectory dicts. Each record has the same task_id as its source trajectory, with LLM-generated activation_text and teacher_text. |

Source code in libs/model-training/src/model_training/d2l_data.py
def augment_trajectories(
    trajectories: list[dict[str, Any]],
    n_variants: int = 3,
    model: str = "qwen2.5-coder:1.5b",
    ollama_base_url: str | None = None,
) -> list[dict[str, Any]]:
    """Produce LLM-augmented variants of trajectory records.

    For each input trajectory, generates up to n_variants augmented records
    using an Ollama LLM. Augmented records always inherit the source task_id
    to preserve split integrity when mixed with originals.

    Augmentation strategies (up to n_variants selected in order):
    1. Paraphrase the task description
    2. Reorder/drop steps in the trajectory
    3. Rename variables throughout the trajectory

    Args:
        trajectories: List of trajectory record dicts with task_id,
            activation_text, and teacher_text fields.
        n_variants: Number of augmented variants to produce per trajectory.
            Maximum 3 (one per augmentation strategy).
        model: Ollama model identifier to use for augmentation.
        ollama_base_url: Ollama base URL. Defaults to "http://localhost:11434".

    Returns:
        List of augmented trajectory dicts. Each record has the same task_id
        as its source trajectory, with LLM-generated activation_text and
        teacher_text.
    """
    import asyncio  # noqa: PLC0415

    from inference.ollama_provider import OllamaProvider  # noqa: PLC0415

    base_url = ollama_base_url or "http://localhost:11434"
    provider = OllamaProvider(base_url=base_url)

    augmentation_prompts = [
        (
            "Paraphrase the following coding task description and trajectory, "
            "keeping the same logic but using different wording:\n\n{text}"
        ),
        (
            "Rewrite the following coding trajectory by reordering and slightly "
            "dropping some intermediate steps, while preserving the overall "
            "outcome:\n\n{text}"
        ),
        (
            "Rewrite the following code trajectory by renaming variables and "
            "function parameters to different names, keeping the same logic:"
            "\n\n{text}"
        ),
    ]

    async def _augment_one(
        trajectory: dict[str, Any],
        prov: Any,
        n: int,
        mdl: str,
    ) -> list[dict[str, Any]]:
        task_id: str = trajectory["task_id"]
        activation_text: str = trajectory.get("activation_text", "")
        teacher_text: str = trajectory.get("teacher_text", "")

        results: list[dict[str, Any]] = []
        strategies = augmentation_prompts[:n]
        for strategy_prompt in strategies:
            augment_prompt = strategy_prompt.format(text=teacher_text)
            result = await prov.generate(augment_prompt, mdl)
            augmented_teacher = result.text

            # Build augmented activation by stripping the solution portion
            # Use teacher text length ratio to approximate activation
            activation_prompt = strategy_prompt.format(text=activation_text)
            act_result = await prov.generate(activation_prompt, mdl)
            augmented_activation = act_result.text

            results.append(
                {
                    "task_id": task_id,  # CRITICAL: inherit source task_id
                    "activation_text": augmented_activation,
                    "teacher_text": augmented_teacher,
                }
            )
        return results

    async def _augment_all(
        trajs: list[dict[str, Any]],
        n: int,
        mdl: str,
    ) -> list[dict[str, Any]]:
        all_tasks = [_augment_one(t, provider, n, mdl) for t in trajs]
        nested = await asyncio.gather(*all_tasks)
        return [record for group in nested for record in group]

    augmented = asyncio.run(_augment_all(trajectories, n_variants, model))
    logger.info(
        "Augmented %d trajectories into %d records (n_variants=%d)",
        len(trajectories),
        len(augmented),
        n_variants,
    )
    return augmented
split_by_task_id
split_by_task_id(
    records: list[dict[str, Any]],
    test_fraction: float = 0.2,
    seed: int = 42,
) -> tuple[list[dict[str, Any]], list[dict[str, Any]]]

Split records at task-ID boundary with no task_id crossing train/test.

Augmented records that share a task_id are always assigned to the same partition, preventing task-family leakage.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `records` | `list[dict[str, Any]]` | List of record dicts, each with a 'task_id' field. | required |
| `test_fraction` | `float` | Fraction of unique task_ids to assign to the test set. At least one task_id goes to test even if the fraction rounds to 0. | 0.2 |
| `seed` | `int` | Random seed for reproducible splits. | 42 |

Returns:

| Type | Description |
| --- | --- |
| `tuple[list[dict[str, Any]], list[dict[str, Any]]]` | Tuple of (train_records, test_records) where task_ids never overlap. |

Source code in libs/model-training/src/model_training/d2l_data.py
def split_by_task_id(
    records: list[dict[str, Any]],
    test_fraction: float = 0.2,
    seed: int = 42,
) -> tuple[list[dict[str, Any]], list[dict[str, Any]]]:
    """Split records at task-ID boundary with no task_id crossing train/test.

    Augmented records that share a task_id are always assigned to the same
    partition, preventing task-family leakage.

    Args:
        records: List of record dicts, each with a 'task_id' field.
        test_fraction: Fraction of unique task_ids to assign to test set.
            Minimum 1 task_id goes to test even if fraction rounds to 0.
        seed: Random seed for reproducible splits.

    Returns:
        Tuple of (train_records, test_records) where task_ids never overlap.
    """
    task_ids = sorted({r["task_id"] for r in records})
    rng = random.Random(seed)
    rng.shuffle(task_ids)
    n_test = max(1, int(len(task_ids) * test_fraction))
    test_ids = set(task_ids[:n_test])
    train = [r for r in records if r["task_id"] not in test_ids]
    test = [r for r in records if r["task_id"] in test_ids]
    return train, test
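Because the shuffle operates on unique task_ids rather than individual records, augmented variants can never leak across partitions. A runnable sketch, copying the documented body so it works without importing `model_training`:

```python
import random
from typing import Any


def split_by_task_id(
    records: list[dict[str, Any]],
    test_fraction: float = 0.2,
    seed: int = 42,
) -> tuple[list[dict[str, Any]], list[dict[str, Any]]]:
    # Shuffle task_ids, not records, so every record sharing a task_id
    # lands in the same partition.
    task_ids = sorted({r["task_id"] for r in records})
    rng = random.Random(seed)
    rng.shuffle(task_ids)
    n_test = max(1, int(len(task_ids) * test_fraction))
    test_ids = set(task_ids[:n_test])
    train = [r for r in records if r["task_id"] not in test_ids]
    test = [r for r in records if r["task_id"] in test_ids]
    return train, test


# Three records per task, as if two augmented variants were added per original.
records = [
    {"task_id": f"task_{i}", "variant": v} for i in range(10) for v in range(3)
]
train, test = split_by_task_id(records)
train_ids = {r["task_id"] for r in train}
test_ids = {r["task_id"] for r in test}
```

With 10 task_ids and `test_fraction=0.2`, exactly 2 task families (6 records) go to test, and the split is reproducible for a fixed seed.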
normalize_mined_pairs
normalize_mined_pairs(
    trajectory: dict[str, Any],
    compress: bool = True,
    max_diff_lines: int = 500,
    language: str | None = None,
) -> list[dict[str, Any]]

Convert a mined PR trajectory into per-step training pairs.

Each review-to-revision cycle becomes one training record with activation_text (task + current code + review feedback) and teacher_text (activation + revision diff). Compatible with augment_trajectories, split_by_task_id, and save_jsonl.

The algorithm groups contiguous commits and reviews into blocks, then pairs each reviews-block with the following commits-block. When a block contains multiple commits, the last one is used (the state the reviewer actually saw); multiple reviews in a block are concatenated.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `trajectory` | `dict[str, Any]` | Raw mined trajectory from mine_pr_diff_chains. | required |
| `compress` | `bool` | Apply diff compression via compress_diff. | True |
| `max_diff_lines` | `int` | Max lines per compressed diff. | 500 |
| `language` | `str \| None` | Language tag for metadata (from repos config). | None |

Returns:

| Type | Description |
| --- | --- |
| `list[dict[str, Any]]` | List of training pair records with task_id, activation_text, teacher_text, and metadata fields. |

Source code in libs/model-training/src/model_training/d2l_data.py
def normalize_mined_pairs(
    trajectory: dict[str, Any],
    compress: bool = True,
    max_diff_lines: int = 500,
    language: str | None = None,
) -> list[dict[str, Any]]:
    """Convert a mined PR trajectory into per-step training pairs.

    Each review-to-revision cycle becomes one training record with
    activation_text (task + current code + review feedback) and
    teacher_text (activation + revision diff). Compatible with
    ``augment_trajectories``, ``split_by_task_id``, and ``save_jsonl``.

    The algorithm groups contiguous commits and reviews into blocks,
    then pairs each reviews-block with the following commits-block.
    Multiple commits in a block: the last commit is used (the state
    the reviewer actually saw). Multiple reviews: concatenated.

    Args:
        trajectory: Raw mined trajectory from ``mine_pr_diff_chains``.
        compress: Apply diff compression via ``compress_diff``.
        max_diff_lines: Max lines per compressed diff.
        language: Language tag for metadata (from repos config).

    Returns:
        List of training pair records with task_id, activation_text,
        teacher_text, and metadata fields.
    """
    task_id: str = trajectory.get("task_id", "")
    task_desc: str = trajectory.get("task_description", "")
    raw_steps: list[dict[str, Any]] = trajectory.get("steps", [])
    outcome: str = trajectory.get("outcome", "")

    if not raw_steps:
        return []

    blocks = _group_steps_into_blocks(raw_steps)

    def _diff(step: dict[str, Any]) -> str:
        raw = step.get("content", "")
        return compress_diff(raw, max_lines=max_diff_lines) if compress else raw

    def _record(idx: int, activation: str, teacher: str) -> dict[str, Any]:
        return _make_pair_record(task_id, outcome, language, idx, activation, teacher)

    records: list[dict[str, Any]] = []
    step_idx = 0
    prev_diff = ""
    bi = 0  # block index

    # --- Step 0: initial commits block ---
    if blocks[0][0] == "commit":
        last_commit = blocks[0][1][-1]
        diff = _diff(last_commit)
        activation = f"## Task\n{task_desc}"
        teacher = f"{activation}\n\n## Implementation\n{diff}"
        records.append(_record(step_idx, activation, teacher))
        prev_diff = diff
        step_idx += 1
        bi = 1

    # --- Subsequent (reviews, commits) pairs ---
    while bi < len(blocks) - 1:
        if blocks[bi][0] == "review" and blocks[bi + 1][0] == "commit":
            review_text = "\n\n".join(r.get("content", "") for r in blocks[bi][1])
            revision = _diff(blocks[bi + 1][1][-1])
            activation = f"## Task\n{task_desc}"
            if prev_diff:
                activation += f"\n\n## Current Code\n{prev_diff}"
            activation += f"\n\n## Review Feedback\n{review_text}"
            teacher = f"{activation}\n\n## Revision\n{revision}"
            records.append(_record(step_idx, activation, teacher))
            prev_diff = revision
            step_idx += 1
            bi += 2
        else:
            bi += 1

    return records
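The grouping step is easy to misread: contiguous steps of the same type collapse into one block before pairing. `_group_steps_into_blocks` is private, so the sketch below assumes it behaves like `itertools.groupby` over the step type, which matches the pairing logic in the function body:

```python
from itertools import groupby
from typing import Any


def group_steps_into_blocks(
    steps: list[dict[str, Any]],
) -> list[tuple[str, list[dict[str, Any]]]]:
    # Hypothetical stand-in for the private _group_steps_into_blocks helper:
    # contiguous runs of the same step type become one (type, steps) block.
    return [(t, list(g)) for t, g in groupby(steps, key=lambda s: s.get("type", ""))]


steps = [
    {"type": "commit", "content": "diff A1"},
    {"type": "commit", "content": "diff A2"},
    {"type": "review", "content": "please rename x"},
    {"type": "commit", "content": "diff B"},
]
blocks = group_steps_into_blocks(steps)
```

Here the two initial commits form one block whose last diff ("diff A2") is the code state the reviewer saw, and the review/commit pair that follows becomes one training record.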

d2l_diff

RTK-style diff compression for training data preparation.

Filters irrelevant files (lockfiles, generated code, binary assets, build artifacts) and truncates large diffs to minimize token overhead in hypernetwork training pairs.

Operates on the concatenated diff format produced by mine_pr_diff_chains:

    --- src/main.py ---
    +real code
    --- package-lock.json ---
    +lockfile noise

Functions
compress_diff
compress_diff(content: str, max_lines: int = 500) -> str

Filter irrelevant files and truncate large diffs.

Parses the --- filename --- section format produced by mine_pr_diff_chains and removes lockfiles, generated code, binary files, and build artifacts.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `content` | `str` | Concatenated diff content with section headers. | required |
| `max_lines` | `int` | Maximum total output lines. | 500 |

Returns:

| Type | Description |
| --- | --- |
| `str` | Filtered and truncated diff string. |

Source code in libs/model-training/src/model_training/d2l_diff.py
def compress_diff(content: str, max_lines: int = 500) -> str:
    """Filter irrelevant files and truncate large diffs.

    Parses the ``--- filename ---`` section format produced by
    ``mine_pr_diff_chains`` and removes lockfiles, generated code,
    binary files, and build artifacts.

    Args:
        content: Concatenated diff content with section headers.
        max_lines: Maximum total output lines.

    Returns:
        Filtered and truncated diff string.
    """
    if not content:
        return content

    parts = _SECTION_RE.split(content)

    # No section headers found -- fall through to truncation only
    if len(parts) == 1:
        return _truncate(content, max_lines)

    # parts = [preamble, header1, body1, header2, body2, ...]
    kept: list[str] = []
    idx = 1  # skip preamble (parts[0])
    while idx < len(parts):
        header = parts[idx]
        body = parts[idx + 1] if idx + 1 < len(parts) else ""
        filename = header[4:-4]  # strip "--- " and " ---"
        if not _should_skip(filename):
            kept.append(header + body)
        idx += 2

    result = "".join(kept).strip()

    if not result:
        return ""

    return _truncate(result, max_lines)
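The split-keep-delimiters trick is the heart of the parser. `_SECTION_RE` and `_should_skip` are private, so the regex and skip set below are assumptions chosen to match the documented `--- filename ---` format; the alternating `[preamble, header1, body1, ...]` walk is the same as in the function body:

```python
import re

# Assumed stand-in for the private _SECTION_RE: the capturing group makes
# re.split keep each "--- file ---" header as its own list element.
SECTION_RE = re.compile(r"(^--- .+? ---$)", re.MULTILINE)

# Illustrative skip set; the real _should_skip covers lockfiles, generated
# code, binaries, and build artifacts.
SKIP_NAMES = {"package-lock.json", "yarn.lock", "poetry.lock"}


def should_skip(filename: str) -> bool:
    return filename in SKIP_NAMES


content = (
    "--- src/main.py ---\n+real code\n"
    "--- package-lock.json ---\n+lockfile noise\n"
)
parts = SECTION_RE.split(content)
# parts = [preamble, header1, body1, header2, body2, ...]
kept: list[str] = []
for i in range(1, len(parts), 2):
    header = parts[i]
    body = parts[i + 1] if i + 1 < len(parts) else ""
    filename = header[4:-4]  # strip "--- " and " ---"
    if not should_skip(filename):
        kept.append(header + body)
result = "".join(kept).strip()
```

The lockfile section disappears while the real code section survives intact.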

d2l_lora

Functional LoRA injection via context manager for hypernetwork training.

Patches transformer attention projection modules with F.linear forward functions that carry live hypernetwork tensor graph nodes, preserving autograd continuity through A and B matrices back to the hypernetwork head.

Unlike PEFT's get_peft_model (which severs the autograd graph by copying tensors into new nn.Parameter objects), this approach uses closures over the original A/B tensors so that loss.backward() propagates gradients through the LoRA path all the way to the hypernetwork parameters.

All heavy GPU imports (torch, transformers) are deferred to function bodies per INFRA-05 project convention.

Functions
apply_functional_lora
apply_functional_lora(
    model: Any, lora_dict: dict[str, Any], hc: Any
) -> _FunctionalLoRAContext

Create a context manager that patches model with functional LoRA.

Usage:

    with apply_functional_lora(base_model, lora_dict, hc):
        output = base_model(input_ids)
        loss = criterion(output, target)
    loss.backward()  # gradients flow through A/B to hypernetwork

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `model` | `Any` | Base transformer model (nn.Module). | required |
| `lora_dict` | `dict[str, Any]` | Dict from HyperLoRA.generate_weights(). `lora_dict[proj_name]["A"]` has shape (batch=1, n_layers, r, d_in); `lora_dict[proj_name]["B"]` has shape (batch=1, n_layers, r, d_out). Keys match hc.lora_config.target_modules. The batch dimension is always 1 (squeezed at index 0). | required |
| `hc` | `Any` | HypernetConfig with attributes hc.lora_config.target_modules (list of projection names), hc.lora_config.r (LoRA rank), hc.lora_config.lora_alpha (scaling numerator), and hc.layer_indices (list of absolute layer indices). | required |

Returns:

| Type | Description |
| --- | --- |
| `_FunctionalLoRAContext` | _FunctionalLoRAContext to use as a context manager. |

Source code in libs/model-training/src/model_training/d2l_lora.py
def apply_functional_lora(
    model: Any, lora_dict: dict[str, Any], hc: Any
) -> _FunctionalLoRAContext:
    """Create a context manager that patches model with functional LoRA.

    Usage:
        with apply_functional_lora(base_model, lora_dict, hc):
            output = base_model(input_ids)
            loss = criterion(output, target)
        loss.backward()  # gradients flow through A/B to hypernetwork

    Args:
        model: Base transformer model (nn.Module).
        lora_dict: Dict from HyperLoRA.generate_weights(). Structure:
            lora_dict[proj_name]["A"] shape: (batch=1, n_layers, r, d_in)
            lora_dict[proj_name]["B"] shape: (batch=1, n_layers, r, d_out)
            Keys match hc.lora_config.target_modules.
            Batch dimension is always 1 (squeezed at index 0).
        hc: HypernetConfig with attributes:
            hc.lora_config.target_modules: list of projection names
            hc.lora_config.r: LoRA rank
            hc.lora_config.lora_alpha: scaling numerator
            hc.layer_indices: list of absolute layer indices

    Returns:
        _FunctionalLoRAContext to use as context manager.
    """
    return _FunctionalLoRAContext(model, lora_dict, hc)
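The key mechanism is that the patched forward is a closure over the original A/B tensors rather than a copy into new nn.Parameter objects, so whatever produced them (the hypernetwork graph) stays reachable. A torch-free sketch of that patch-and-restore pattern on scalar "weights" (all names here are illustrative; the real context patches F.linear forwards on attention projections):

```python
from contextlib import contextmanager
from typing import Iterator


class Projection:
    """Minimal stand-in for an attention projection module."""

    def __init__(self, scale: float) -> None:
        self.scale = scale

    def forward(self, x: float) -> float:
        return self.scale * x


@contextmanager
def apply_functional_lora(
    module: Projection, a: float, b: float, alpha: float, r: int
) -> Iterator[Projection]:
    # The replacement forward closes over a and b: no copies are made, which
    # is what preserves autograd continuity in the real torch implementation.
    original_forward = module.forward
    scaling = alpha / r  # lora_alpha / rank, as in standard LoRA scaling

    def lora_forward(x: float) -> float:
        return original_forward(x) + scaling * (x * a) * b

    module.forward = lora_forward
    try:
        yield module
    finally:
        module.forward = original_forward  # unpatch on exit


layer = Projection(scale=2.0)
with apply_functional_lora(layer, a=0.5, b=4.0, alpha=16, r=8):
    patched = layer.forward(3.0)   # 2*3 + (16/8)*(3*0.5)*4 = 18.0
restored = layer.forward(3.0)      # back to 2*3 = 6.0
```

As in the documented usage, the LoRA path contributes only inside the context, and exiting restores the base forward without mutating any weights.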

d2l_mining

GitHub trajectory mining for coding session distillation.

Mines PR diff chains and issue-commit chains from GitHub repositories, producing trajectory dicts suitable for normalization and distillation. Designed to run on an L4 VM with network access and a GITHUB_TOKEN.

Classes
Functions
search_quality_prs
search_quality_prs(
    repo: str,
    max_results: int = 100,
    github_token: str | None = None,
    min_review_comments: int = 1,
    min_commits: int = 2,
    exclude_labels: list[str] | None = None,
) -> list[int]

Search for high-quality merged PRs using the GitHub Search API.

Pre-filters PRs by review approval, comment count, label exclusion, and minimum commit count to identify PRs with meaningful review trajectories suitable for distillation.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `repo` | `str` | GitHub repository in "owner/repo" format. | required |
| `max_results` | `int` | Maximum number of qualifying PR numbers to return. | 100 |
| `github_token` | `str \| None` | Personal access token for GitHub API authentication. | None |
| `min_review_comments` | `int` | Minimum number of comments for the search query. | 1 |
| `min_commits` | `int` | Minimum number of commits a PR must have. | 2 |
| `exclude_labels` | `list[str] \| None` | Labels to exclude. Defaults to common non-code labels. | None |

Returns:

| Type | Description |
| --- | --- |
| `list[int]` | List of qualifying PR numbers. |

Source code in libs/model-training/src/model_training/d2l_mining.py
def search_quality_prs(
    repo: str,
    max_results: int = 100,
    github_token: str | None = None,
    min_review_comments: int = 1,
    min_commits: int = 2,
    exclude_labels: list[str] | None = None,
) -> list[int]:
    """Search for high-quality merged PRs using the GitHub Search API.

    Pre-filters PRs by review approval, comment count, label exclusion,
    and minimum commit count to identify PRs with meaningful review
    trajectories suitable for distillation.

    Args:
        repo: GitHub repository in "owner/repo" format.
        max_results: Maximum number of qualifying PR numbers to return.
        github_token: Personal access token for GitHub API authentication.
        min_review_comments: Minimum number of comments for search query.
        min_commits: Minimum number of commits a PR must have.
        exclude_labels: Labels to exclude. Defaults to common non-code labels.

    Returns:
        List of qualifying PR numbers.
    """
    client = GitHubClient(token=github_token)
    labels_to_exclude = (
        frozenset(exclude_labels)
        if exclude_labels is not None
        else _DEFAULT_EXCLUDE_LABELS
    )

    query = (
        f"repo:{repo} is:pr is:merged review:approved comments:>{min_review_comments}"
    )
    per_page = min(max_results, 100)
    pages_needed = math.ceil(max_results / 100)

    all_items: list[dict[str, Any]] = []
    for page in range(1, pages_needed + 1):
        data = client.get(
            "/search/issues",
            params={
                "q": query,
                "sort": "updated",
                "order": "desc",
                "per_page": per_page,
                "page": page,
            },
        )
        items = data.get("items", [])
        all_items.extend(items)
        if len(items) < per_page:
            break

    total = len(all_items)

    # Label filter
    after_label: list[dict[str, Any]] = []
    for item in all_items:
        item_labels = {lbl["name"] for lbl in item.get("labels", [])}
        if not item_labels & labels_to_exclude:
            after_label.append(item)

    # Commit count filter
    result: list[int] = []
    for item in after_label:
        pr_number = item["number"]
        detail = client.get(f"/repos/{repo}/pulls/{pr_number}")
        if detail.get("commits", 0) >= min_commits:
            result.append(pr_number)
        if len(result) >= max_results:
            break

    logger.info(
        "Search found %d candidates, %d after label filter, %d after commit filter",
        total,
        len(after_label),
        len(result),
    )
    return result
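The label filter above is a plain set intersection: a PR survives only if none of its labels appear in the exclusion set. A minimal standalone sketch (the exclusion labels here are illustrative, not the real `_DEFAULT_EXCLUDE_LABELS`):

```python
# Illustrative exclusion set; the real defaults live in d2l_mining.py.
labels_to_exclude = frozenset({"documentation", "dependencies"})

all_items = [
    {"number": 101, "labels": [{"name": "bug"}]},
    {"number": 102, "labels": [{"name": "documentation"}]},
    {"number": 103, "labels": []},
]

# A PR passes when its label names do not intersect the exclusion set.
after_label = [
    item
    for item in all_items
    if not {lbl["name"] for lbl in item.get("labels", [])} & labels_to_exclude
]
print([item["number"] for item in after_label])  # → [101, 103]
```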
mine_pr_diff_chains
mine_pr_diff_chains(
    repo: str,
    max_prs: int = 100,
    github_token: str | None = None,
    pr_numbers: list[int] | None = None,
) -> list[dict[str, Any]]

Extract PR diff chains from a GitHub repository.

Each chain represents an iterative coding session: initial commit -> review comments -> revision commits. The resulting trajectory records capture the back-and-forth of code review as a multi-step improvement process, suitable for distillation.

Returns trajectory dicts with the following fields:

- task_id: f"pr_{repo}_{pr_number}"
- task_description: PR title concatenated with body text
- steps: list of commit diffs and review comments in chronological order
- outcome: "merged" or "closed" depending on PR final state

Parameters:

Name Type Description Default
repo str

GitHub repository in "owner/repo" format.

required
max_prs int

Maximum number of PRs to process. Defaults to 100.

100
github_token str | None

Personal access token for GitHub API authentication.

None
pr_numbers list[int] | None

Optional list of specific PR numbers to mine. When provided, skips the paginated PR list fetch and fetches each PR individually.

None

Returns:

Type Description
list[dict[str, Any]]

List of trajectory dicts representing PR diff chains.

Source code in libs/model-training/src/model_training/d2l_mining.py
def mine_pr_diff_chains(
    repo: str,
    max_prs: int = 100,
    github_token: str | None = None,
    pr_numbers: list[int] | None = None,
) -> list[dict[str, Any]]:
    """Extract PR diff chains from a GitHub repository.

    Each chain represents an iterative coding session: initial commit ->
    review comments -> revision commits. The resulting trajectory records
    capture the back-and-forth of code review as a multi-step improvement
    process, suitable for distillation.

    Returns trajectory dicts with the following fields:
    - task_id: f"pr_{repo}_{pr_number}"
    - task_description: PR title concatenated with body text
    - steps: list of commit diffs and review comments in chronological order
    - outcome: "merged" or "closed" depending on PR final state

    Args:
        repo: GitHub repository in "owner/repo" format.
        max_prs: Maximum number of PRs to process. Defaults to 100.
        github_token: Personal access token for GitHub API authentication.
        pr_numbers: Optional list of specific PR numbers to mine. When
            provided, skips the paginated PR list fetch and fetches each
            PR individually.

    Returns:
        List of trajectory dicts representing PR diff chains.
    """
    client = GitHubClient(token=github_token)

    if pr_numbers is not None:
        prs = [client.get(f"/repos/{repo}/pulls/{n}") for n in pr_numbers[:max_prs]]
    else:
        prs = client.get_paginated(
            f"/repos/{repo}/pulls",
            params={"state": "closed", "sort": "updated", "direction": "desc"},
            max_pages=math.ceil(max_prs / 100),
        )
        prs = prs[:max_prs]

    if not prs:
        return []

    trajectories: list[dict[str, Any]] = []

    for pr in prs:
        pr_number = pr["number"]
        title = pr.get("title", "")
        body = pr.get("body", "") or ""

        commits = client.get_paginated(
            f"/repos/{repo}/pulls/{pr_number}/commits",
            max_pages=5,
        )
        reviews = client.get_paginated(
            f"/repos/{repo}/pulls/{pr_number}/comments",
            max_pages=3,
        )

        # Build timestamped steps for chronological interleaving
        timed_steps: list[tuple[str, dict[str, str]]] = []

        for commit in commits:
            sha = commit["sha"]
            msg = commit["commit"]["message"]
            ts = commit["commit"].get("committer", {}).get("date", "")
            detail = client.get(f"/repos/{repo}/commits/{sha}")
            files = detail.get("files", [])
            patches = []
            for f in files:
                patch = f.get("patch", "")
                if patch:
                    patches.append(f"--- {f['filename']} ---\n{patch}")
            timed_steps.append(
                (
                    ts,
                    {
                        "type": "commit",
                        "description": msg,
                        "content": "\n".join(patches),
                    },
                )
            )

        for comment in reviews:
            ts = comment.get("created_at", "")
            timed_steps.append(
                (
                    ts,
                    {
                        "type": "review",
                        "description": "Review comment",
                        "content": comment.get("body", ""),
                    },
                )
            )

        # Sort by timestamp so commits and reviews interleave chronologically
        timed_steps.sort(key=lambda x: x[0])
        steps = [step for _, step in timed_steps]

        outcome = "merged" if pr.get("merged_at") is not None else "closed"

        trajectories.append(
            {
                "task_id": f"pr_{repo}_{pr_number}",
                "task_description": f"{title}\n\n{body}".strip(),
                "steps": steps,
                "outcome": outcome,
            }
        )

    return trajectories
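The chronological interleaving above relies on every step carrying an ISO-8601 timestamp, so a plain lexicographic sort orders commits and review comments in time. A standalone sketch with illustrative timestamps:

```python
# (timestamp, step) pairs as built in mine_pr_diff_chains; the contents
# are illustrative stand-ins for real commit and review data.
timed_steps = [
    ("2024-05-02T10:00:00Z", {"type": "commit", "description": "address review"}),
    ("2024-05-01T09:00:00Z", {"type": "commit", "description": "initial impl"}),
    ("2024-05-01T15:30:00Z", {"type": "review", "description": "Review comment"}),
]

# ISO-8601 strings in the same timezone sort lexicographically by time.
timed_steps.sort(key=lambda x: x[0])
steps = [step for _, step in timed_steps]
print([s["description"] for s in steps])
# → ['initial impl', 'Review comment', 'address review']
```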
mine_issue_commit_chains
mine_issue_commit_chains(
    repo: str,
    max_issues: int = 100,
    github_token: str | None = None,
) -> list[dict[str, Any]]

Link GitHub issues to their fixing commits via commit message references.

Scans commit messages for "fixes #N", "closes #N", or "resolves #N" patterns to identify which commits address which issues. Groups linked commits as trajectory steps for distillation.

Returns trajectory dicts with the following fields:

- task_id: f"issue_{repo}_{issue_number}"
- task_description: issue title concatenated with body text
- steps: list of commits referencing this issue in chronological order
- outcome: "closed" or "open" from the issue state

Parameters:

Name Type Description Default
repo str

GitHub repository in "owner/repo" format.

required
max_issues int

Maximum number of issues to process. Defaults to 100.

100
github_token str | None

Personal access token for GitHub API authentication.

None

Returns:

Type Description
list[dict[str, Any]]

List of trajectory dicts representing issue-commit chains.

Source code in libs/model-training/src/model_training/d2l_mining.py
def mine_issue_commit_chains(
    repo: str,
    max_issues: int = 100,
    github_token: str | None = None,
) -> list[dict[str, Any]]:
    """Link GitHub issues to their fixing commits via commit message references.

    Scans commit messages for "fixes #N", "closes #N", or "resolves #N"
    patterns to identify which commits address which issues. Groups linked
    commits as trajectory steps for distillation.

    Returns trajectory dicts with the following fields:
    - task_id: f"issue_{repo}_{issue_number}"
    - task_description: issue title concatenated with body text
    - steps: list of commits referencing this issue in chronological order
    - outcome: "closed" or "open" from the issue state

    Args:
        repo: GitHub repository in "owner/repo" format.
        max_issues: Maximum number of issues to process. Defaults to 100.
        github_token: Personal access token for GitHub API authentication.

    Returns:
        List of trajectory dicts representing issue-commit chains.
    """
    client = GitHubClient(token=github_token)

    raw_issues = client.get_paginated(
        f"/repos/{repo}/issues",
        params={"state": "all", "sort": "updated", "direction": "desc"},
        max_pages=math.ceil(max_issues / 100),
    )

    # Filter out pull requests (GitHub issues API includes PRs)
    issues = [i for i in raw_issues if not i.get("pull_request")]
    issues = issues[:max_issues]

    issue_numbers = {i["number"] for i in issues}
    issue_map = {i["number"]: i for i in issues}

    repo_commits = client.get_paginated(
        f"/repos/{repo}/commits",
        max_pages=10,
    )

    # Group commits by referenced issue number
    linked: dict[int, list[dict[str, Any]]] = {}
    for commit in repo_commits:
        msg = commit["commit"]["message"]
        refs = _FIXES_RE.findall(msg)
        for ref in refs:
            issue_num = int(ref)
            if issue_num in issue_numbers:
                linked.setdefault(issue_num, []).append(commit)

    trajectories: list[dict[str, Any]] = []

    for issue_num, commits in linked.items():
        issue = issue_map[issue_num]
        title = issue.get("title", "")
        body = issue.get("body", "") or ""

        steps: list[dict[str, str]] = []
        for commit in commits:
            sha = commit["sha"]
            msg = commit["commit"]["message"]
            detail = client.get(f"/repos/{repo}/commits/{sha}")
            files = detail.get("files", [])
            patches = []
            for f in files:
                patch = f.get("patch", "")
                if patch:
                    patches.append(f"--- {f['filename']} ---\n{patch}")
            steps.append(
                {
                    "type": "commit",
                    "description": msg,
                    "content": "\n".join(patches),
                }
            )

        trajectories.append(
            {
                "task_id": f"issue_{repo}_{issue_num}",
                "task_description": f"{title}\n\n{body}".strip(),
                "steps": steps,
                "outcome": issue.get("state", "open"),
            }
        )

    return trajectories
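The `_FIXES_RE` pattern itself is not shown in this excerpt; a hypothetical reconstruction that matches the documented "fixes #N", "closes #N", and "resolves #N" forms:

```python
import re

# Hypothetical reconstruction of _FIXES_RE (the real pattern lives in
# d2l_mining.py); case-insensitive, capturing the issue number.
FIXES_RE = re.compile(r"\b(?:fixes|closes|resolves)\s+#(\d+)", re.IGNORECASE)

msg = "Refactor parser\n\nFixes #42, closes #7"
refs = [int(n) for n in FIXES_RE.findall(msg)]
print(refs)  # → [42, 7]
```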

d2l_prep

Data preparation pipeline for context distillation training.

Converts raw trajectory JSON files into a training JSONL by calling format_for_distillation on each trajectory and persisting the resulting records via save_jsonl.

Usage (CLI): uv run python -m model_training.d2l_prep traj1.json traj2.json -o train.jsonl

Functions
prepare_training_jsonl
prepare_training_jsonl(
    input_paths: list[Path], output_path: Path
) -> int

Convert trajectory JSON files to a training JSONL.

Reads each input file, calls format_for_distillation on every trajectory, collects all returned records, and writes them to output_path via save_jsonl.

Failed trajectories (outcome != 'success') are filtered by format_for_distillation and produce zero records — they do not raise.

Parameters:

Name Type Description Default
input_paths list[Path]

Trajectory JSON files, each containing a single trajectory dict or a JSON array of trajectory dicts.

required
output_path Path

Destination JSONL file. Parent directories are created automatically. File is always written (may be empty).

required

Returns:

Type Description
int

Number of records written.

Source code in libs/model-training/src/model_training/d2l_prep.py
def prepare_training_jsonl(
    input_paths: list[Path],
    output_path: Path,
) -> int:
    """Convert trajectory JSON files to a training JSONL.

    Reads each input file, calls format_for_distillation on every trajectory,
    collects all returned records, and writes them to output_path via save_jsonl.

    Failed trajectories (outcome != 'success') are filtered by
    format_for_distillation and produce zero records — they do not raise.

    Args:
        input_paths: Trajectory JSON files, each containing a single trajectory
            dict or a JSON array of trajectory dicts.
        output_path: Destination JSONL file. Parent directories are created
            automatically. File is always written (may be empty).

    Returns:
        Number of records written.
    """
    all_records: list[dict[str, Any]] = []

    for path in input_paths:
        trajectories = _load_trajectories(path)
        logger.info("Processing %s (%d trajectories)", path, len(trajectories))
        for traj in trajectories:
            records = format_for_distillation(traj)
            all_records.extend(records)
            logger.debug(
                "  trajectory %s: %d records",
                traj.get("task_id", "<no-id>"),
                len(records),
            )

    save_jsonl(all_records, output_path)
    logger.info("Wrote %d records to %s", len(all_records), output_path)
    return len(all_records)
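The persistence step writes one JSON object per line. `save_jsonl` lives in the project's shared utilities; this sketch reimplements only its observable behavior, with illustrative record contents.

```python
import json
import tempfile
from pathlib import Path

# Illustrative records of the kind format_for_distillation returns.
records = [
    {"task_id": "pr_acme/repo_12", "prompt": "...", "completion": "..."},
    {"task_id": "issue_acme/repo_7", "prompt": "...", "completion": "..."},
]

output_path = Path(tempfile.mkdtemp()) / "out" / "train.jsonl"
output_path.parent.mkdir(parents=True, exist_ok=True)  # parent dirs created automatically
with output_path.open("w") as f:
    for rec in records:
        f.write(json.dumps(rec) + "\n")

print(len(output_path.read_text().splitlines()))  # → 2
```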

d2l_probe

Architecture probe and activation extraction for hypernetwork training.

Discovers standard attention layers (those with q_proj/k_proj/v_proj/o_proj children) via model.named_modules(), caches results to JSON, and provides extract_activations_with_model() that accepts a pre-loaded model and tokenizer.

Phase 26 purpose: eliminate hidden_size placeholders and per-call model loading. The probe becomes the single source of truth for layer indices and projection dimensions across the v7.0 pipeline.

All heavy GPU imports (torch, transformers) are deferred to function bodies per INFRA-05 project convention.

Functions
probe_model
probe_model(model: Any) -> dict[str, Any]

Probe a model's architecture to discover standard attention layers.

Iterates model.named_modules() to find layers that have all four attention projection children (q_proj, k_proj, v_proj, o_proj). DeltaNet and other linear-attention layers that lack these projections are skipped.

For each discovered attention layer, captures the in/out dimensions of q_proj, k_proj, v_proj, and o_proj weights.

Parameters:

Name Type Description Default
model Any

Any nn.Module (typically a transformer model).

required

Returns:

Type Description
dict[str, Any]

Dict with keys:

- attention_layer_indices: sorted list of int layer indices
- feature_sizes: dict mapping projection name to {"in": int, "out": int}

Source code in libs/model-training/src/model_training/d2l_probe.py
def probe_model(model: Any) -> dict[str, Any]:
    """Probe a model's architecture to discover standard attention layers.

    Iterates model.named_modules() to find layers that have all four attention
    projection children (q_proj, k_proj, v_proj, o_proj). DeltaNet and other
    linear-attention layers that lack these projections are skipped.

    For each discovered attention layer, captures the in/out dimensions of
    q_proj, k_proj, v_proj, and o_proj weights.

    Args:
        model: Any nn.Module (typically a transformer model).

    Returns:
        Dict with keys:
            - attention_layer_indices: sorted list of int layer indices
            - feature_sizes: dict mapping projection name to {"in": int, "out": int}
    """
    attention_layer_indices: list[int] = []
    feature_sizes: dict[str, dict[str, int]] = {}

    for name, module in model.named_modules():
        # Get immediate children names
        child_names = {n.split(".")[-1] for n, _ in module.named_children()}
        if not ATTN_PROJECTIONS.issubset(child_names):
            continue

        # Extract layer index from the last numeric segment in dotted name
        parts = name.split(".")
        layer_idx: int | None = None
        for part in reversed(parts):
            if part.isdigit():
                layer_idx = int(part)
                break
        if layer_idx is None:
            continue

        attention_layer_indices.append(layer_idx)

        # Capture projection dimensions (out_f, in_f = weight.shape)
        for proj_name in ("q_proj", "k_proj", "v_proj", "o_proj"):
            proj = getattr(module, proj_name, None)
            if proj is None or not hasattr(proj, "weight"):
                continue
            out_f, in_f = proj.weight.shape
            if proj_name not in feature_sizes:
                feature_sizes[proj_name] = {"in": in_f, "out": out_f}

    return {
        "attention_layer_indices": sorted(attention_layer_indices),
        "feature_sizes": feature_sizes,
    }
save_probe_cache
save_probe_cache(
    model_name: str, probe_result: dict[str, Any]
) -> Path

Persist probe results to JSON cache.

Adds metadata fields (model_name, model_name_hash, probed_at) to a copy of probe_result. Creates PROBE_CACHE_DIR if it does not exist.

Parameters:

Name Type Description Default
model_name str

Canonical model identifier (used for cache lookup key).

required
probe_result dict[str, Any]

Output from probe_model().

required

Returns:

Type Description
Path

Path to the written JSON file.

Source code in libs/model-training/src/model_training/d2l_probe.py
def save_probe_cache(model_name: str, probe_result: dict[str, Any]) -> Path:
    """Persist probe results to JSON cache.

    Adds metadata fields (model_name, model_name_hash, probed_at) to a copy
    of probe_result. Creates PROBE_CACHE_DIR if it does not exist.

    Args:
        model_name: Canonical model identifier (used for cache lookup key).
        probe_result: Output from probe_model().

    Returns:
        Path to the written JSON file.
    """
    cache_path = _model_name_to_cache_path(model_name)
    PROBE_CACHE_DIR.mkdir(parents=True, exist_ok=True)

    data = dict(probe_result)
    data["model_name"] = model_name
    data["model_name_hash"] = hashlib.sha256(model_name.encode()).hexdigest()[:16]
    data["probed_at"] = datetime.now(timezone.utc).isoformat()

    cache_path.write_text(json.dumps(data, indent=2))
    logger.info("Probe cache saved: %s%s", model_name, cache_path)
    return cache_path
load_probe_cache
load_probe_cache(model_name: str) -> dict[str, Any] | None

Load probe results from JSON cache.

Parameters:

Name Type Description Default
model_name str

Canonical model identifier.

required

Returns:

Type Description
dict[str, Any] | None

Probe result dict (including metadata fields) if cached, else None.

dict[str, Any] | None

Never raises — returns None on any miss.

Source code in libs/model-training/src/model_training/d2l_probe.py
def load_probe_cache(model_name: str) -> dict[str, Any] | None:
    """Load probe results from JSON cache.

    Args:
        model_name: Canonical model identifier.

    Returns:
        Probe result dict (including metadata fields) if cached, else None.
        Never raises — returns None on any miss.
    """
    cache_path = _model_name_to_cache_path(model_name)
    if not cache_path.exists():
        logger.debug("Probe cache miss for '%s' (path: %s)", model_name, cache_path)
        return None

    data: dict[str, Any] = json.loads(cache_path.read_text())
    logger.debug("Probe cache hit for '%s'", model_name)
    return data
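The save/load pair above amounts to a hashed-path JSON round-trip. This sketch uses a temporary directory and an illustrative model name as stand-ins for `PROBE_CACHE_DIR` and `_model_name_to_cache_path`:

```python
import hashlib
import json
import tempfile
from pathlib import Path

# Stand-ins for the module internals: a temp cache dir and a hash-based path.
cache_dir = Path(tempfile.mkdtemp())
model_name = "some-org/some-model"  # illustrative identifier

key = hashlib.sha256(model_name.encode()).hexdigest()[:16]
cache_path = cache_dir / f"{key}.json"

# Save: probe result plus metadata, as in save_probe_cache.
probe_result = {"attention_layer_indices": [3, 7, 11], "feature_sizes": {}}
data = {**probe_result, "model_name": model_name, "model_name_hash": key}
cache_path.write_text(json.dumps(data, indent=2))

# Load: return None on a miss, the parsed dict on a hit.
loaded = json.loads(cache_path.read_text()) if cache_path.exists() else None
print(loaded["attention_layer_indices"])  # → [3, 7, 11]
```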
extract_activations_with_model
extract_activations_with_model(
    text: str,
    model: Any,
    tokenizer: Any,
    layer_indices: list[int] | None = None,
    model_name: str | None = None,
    max_length: int = 512,
) -> tuple[Any, Any]

Extract per-layer hidden state activations from a pre-loaded model.

Runs text through the model with output_hidden_states=True and stacks activations from the specified layer indices. Uses hidden_states[i] directly (no +1 offset) — consistent with existing sakana_d2l.py convention.

Parameters:

Name Type Description Default
text str

Input text to tokenize and process.

required
model Any

Pre-loaded nn.Module in eval mode.

required
tokenizer Any

Pre-loaded tokenizer.

required
layer_indices list[int] | None

Which hidden state indices to extract. If None, loads from probe cache via model_name.

None
model_name str | None

Canonical model name for cache lookup (required when layer_indices is None).

None
max_length int

Max token sequence length.

512

Returns:

Type Description
tuple[Any, Any]

Tuple of (features, attention_mask):

- features shape: (1, num_layers, seq_len, hidden_dim)
- attention_mask shape: (1, seq_len)

Raises:

Type Description
RuntimeError

If layer_indices is None and no probe cache exists for model_name.

Source code in libs/model-training/src/model_training/d2l_probe.py
def extract_activations_with_model(
    text: str,
    model: Any,
    tokenizer: Any,
    layer_indices: list[int] | None = None,
    model_name: str | None = None,
    max_length: int = 512,
) -> tuple[Any, Any]:
    """Extract per-layer hidden state activations from a pre-loaded model.

    Runs text through the model with output_hidden_states=True and stacks
    activations from the specified layer indices. Uses hidden_states[i] directly
    (no +1 offset) — consistent with existing sakana_d2l.py convention.

    Args:
        text: Input text to tokenize and process.
        model: Pre-loaded nn.Module in eval mode.
        tokenizer: Pre-loaded tokenizer.
        layer_indices: Which hidden state indices to extract. If None, loads
            from probe cache via model_name.
        model_name: Canonical model name for cache lookup (required when
            layer_indices is None).
        max_length: Max token sequence length.

    Returns:
        Tuple of (features, attention_mask):
            features shape: (1, num_layers, seq_len, hidden_dim)
            attention_mask shape: (1, seq_len)

    Raises:
        RuntimeError: If layer_indices is None and no probe cache exists for model_name.
    """
    import torch  # noqa: PLC0415

    if layer_indices is None:
        if model_name is None:
            msg = (
                "layer_indices is None and model_name is None — "
                "cannot load from probe cache without a model name."
            )
            raise RuntimeError(msg)
        cache = load_probe_cache(model_name)
        if cache is None:
            msg = (
                f"layer_indices is None but no probe cache found for '{model_name}'. "
                "Run probe_model() and save_probe_cache() first."
            )
            raise RuntimeError(msg)
        layer_indices = cache["attention_layer_indices"]

    # Determine device from model
    try:
        device = next(model.parameters()).device
    except StopIteration:
        device = torch.device("cpu")

    inputs = tokenizer(
        text, return_tensors="pt", truncation=True, max_length=max_length
    )
    inputs = {k: v.to(device) for k, v in inputs.items()}

    with torch.no_grad():
        outputs = model(**inputs, output_hidden_states=True)

    hidden_states = outputs.hidden_states
    # Stack selected layers: (batch, num_layers, seq_len, hidden_dim)
    selected = torch.stack([hidden_states[i] for i in layer_indices], dim=1)
    attention_mask = inputs["attention_mask"]

    logger.info(
        "Extracted activations: %s from %d layers", selected.shape, len(layer_indices)
    )
    return selected, attention_mask
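The stacking step's shape arithmetic can be checked without torch. In this sketch nested lists stand in for tensors: each `hidden_states[i]` is `(batch=1, seq_len, hidden_dim)`, and stacking the selected indices along a new dim 1 yields `(1, num_layers, seq_len, hidden_dim)`; the dimensions are illustrative.

```python
seq_len, hidden_dim = 4, 8
num_hidden = 13  # e.g. embeddings plus 12 layers

# Each entry mimics a (1, seq_len, hidden_dim) hidden-state tensor.
hidden_states = [
    [[[0.0] * hidden_dim for _ in range(seq_len)]]
    for _ in range(num_hidden)
]

layer_indices = [3, 7, 11]
# "Stack" the selected layers along a new dim=1, keeping batch=1.
selected = [[hidden_states[i][0] for i in layer_indices]]

shape = (len(selected), len(selected[0]), len(selected[0][0]), len(selected[0][0][0]))
print(shape)  # → (1, 3, 4, 8)
```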

d2l_train

KL-divergence context distillation training loop for Qwen3-Coder-Next.

Assembles all Phase 25-28 components (config, data pipeline, activation extraction, weight transfer, functional LoRA injection) into a complete distillation training script.

Three execution modes:

- dry-run: loads real base model + hypernet, validates tensor shapes, exits
- smoke-test: 5 training steps, verifies finite loss and decreasing trend
- full: trains from JSONL dataset with tiered checkpointing and MLflow tracking

All heavy GPU imports (torch, transformers, peft) are deferred to function bodies per INFRA-05 project convention.

Usage

uv run python -m model_training.d2l_train --dry-run
uv run python -m model_training.d2l_train --smoke-test
uv run python -m model_training.d2l_train --dataset path/to/train.jsonl

Classes
D2LTrainConfig

Bases: BaseModel

Pydantic model for D2L training hyperparameters.

Enables validation, JSON serialization (for checkpoint storage), and .model_dump() for MLflow experiment logging.

Attributes:

Name Type Description
base_model_name str

HuggingFace model name for the student/teacher base.

sakana_checkpoint_path str

Path to the Sakana hypernet checkpoint.

num_steps int

Total training steps.

lr float

Learning rate for AdamW optimizer.

alpha float

Blending weight for KL vs CE loss (1.0 = pure KL, 0.0 = pure CE).

temperature float

Softmax temperature for KL divergence computation.

checkpoint_every int

Steps between lightweight checkpoint saves.

full_checkpoint_every int

Steps between full checkpoint saves (incl. optimizer).

checkpoint_dir str

Directory for checkpoint output.

experiment_name str

MLflow experiment name.

dry_run bool

If True, validate tensor shapes then exit.

smoke_test bool

If True, run 5 steps and verify loss trend.

dataset_path str | None

Path to training JSONL file (required for full training).

grad_clip float

Gradient clipping max norm.

warmup_steps int

Number of linear LR warmup steps.

lora_r int

LoRA rank.

max_length int

Maximum tokenizer sequence length.

Functions
train_d2l_qwen3
train_d2l_qwen3(config: D2LTrainConfig) -> dict[str, Any]

Run KL-divergence context distillation training.

Three execution modes controlled by config flags:

- dry_run=True: Validate shapes with single forward pass, no optimizer step.
- smoke_test=True: Run min(num_steps, 5) steps, assert finite decreasing loss.
- default: Full training from dataset with checkpointing and MLflow tracking.

Parameters:

Name Type Description Default
config D2LTrainConfig

Training configuration.

required

Returns:

Type Description
dict[str, Any]

Dictionary with training results:

- final_loss: Loss at the last step.
- best_loss: Lowest loss seen during training.
- num_steps_completed: Number of training steps completed.
- checkpoint_dir: Path to checkpoint directory.
- shape_summary (dry_run only): Tensor shape validation results.

Source code in libs/model-training/src/model_training/d2l_train.py
def train_d2l_qwen3(config: D2LTrainConfig) -> dict[str, Any]:  # noqa: C901
    """Run KL-divergence context distillation training.

    Three execution modes controlled by config flags:
    - dry_run=True: Validate shapes with single forward pass, no optimizer step.
    - smoke_test=True: Run min(num_steps, 5) steps, assert finite decreasing loss.
    - default: Full training from dataset with checkpointing and MLflow tracking.

    Args:
        config: Training configuration.

    Returns:
        Dictionary with training results:
            - final_loss: Loss at the last step.
            - best_loss: Lowest loss seen during training.
            - num_steps_completed: Number of training steps completed.
            - checkpoint_dir: Path to checkpoint directory.
            - shape_summary (dry_run only): Tensor shape validation results.
    """
    import mlflow  # noqa: PLC0415
    import torch  # noqa: PLC0415
    from ctx_to_lora.modeling.hypernet import HyperLoRA  # noqa: PLC0415
    from torch.nn.utils import clip_grad_norm_  # noqa: PLC0415
    from torch.optim import AdamW  # noqa: PLC0415
    from torch.optim.lr_scheduler import (  # noqa: PLC0415
        CosineAnnealingLR,
        LinearLR,
        SequentialLR,
    )
    from transformers import AutoModelForCausalLM, AutoTokenizer  # noqa: PLC0415

    from model_training.d2l_config import build_qwen3_hypernet_config  # noqa: PLC0415
    from model_training.d2l_data import (  # noqa: PLC0415
        generate_needle_dataset,
        load_jsonl,
        split_by_task_id,
    )
    from model_training.d2l_probe import QWEN3_NEXT_CANONICAL_NAME  # noqa: PLC0415
    from model_training.sakana_d2l import (  # noqa: PLC0415
        get_aggregator_config,
        transfer_aggregator_weights,
    )

    # Mode dispatch: dry_run exits after shape validation
    if config.dry_run:
        shape_summary = _dry_run_validate_shapes(config)
        return {"shape_summary": shape_summary, "status": "dry_run_complete"}

    # Smoke test caps steps at 5
    if config.smoke_test:
        config = config.model_copy(update={"num_steps": min(config.num_steps, 5)})

    # Guard: require probe cache when not in smoke_test mode.
    # smoke_test uses generate_needle_dataset and may not have a real probe cache.
    if not config.smoke_test:
        _require_probe_cache(QWEN3_NEXT_CANONICAL_NAME)

    num_steps = config.num_steps
    warmup_steps = config.warmup_steps

    # Load tokenizer and base model
    logger.info("Loading tokenizer: %s", config.base_model_name)
    tokenizer = AutoTokenizer.from_pretrained(config.base_model_name)

    logger.info("Loading base model: %s", config.base_model_name)
    base_model = AutoModelForCausalLM.from_pretrained(
        config.base_model_name,
        output_hidden_states=True,
    ).eval()

    # Build hypernet config with aggregator config from checkpoint
    logger.info(
        "Building hypernet config from checkpoint: %s",
        config.sakana_checkpoint_path,
    )
    hc = build_qwen3_hypernet_config(
        aggregator_config=get_aggregator_config(config.sakana_checkpoint_path),
        lora_r=config.lora_r,
    )

    # Create hypernet and transfer aggregator weights (freezes aggregator)
    hypernet = HyperLoRA(hc).to(torch.float32)
    hypernet = transfer_aggregator_weights(hypernet, config.sakana_checkpoint_path)
    hypernet.train()

    # Device selection: cuda > mps > cpu
    from shared.hardware import get_best_device  # noqa: PLC0415

    device = torch.device(get_best_device())
    logger.info("Using device: %s", device)
    base_model = base_model.to(device)
    hypernet = hypernet.to(device)

    # Optimizer: only trainable params (head + projections, not frozen aggregator)
    trainable_params = [p for p in hypernet.parameters() if p.requires_grad]
    logger.info("Trainable params: %d", sum(p.numel() for p in trainable_params))
    optimizer = AdamW(trainable_params, lr=config.lr)

    # Scheduler: linear warmup → cosine annealing
    scheduler = SequentialLR(
        optimizer,
        schedulers=[
            LinearLR(optimizer, start_factor=0.01, total_iters=warmup_steps),
            CosineAnnealingLR(
                optimizer,
                T_max=max(1, num_steps - warmup_steps),
                eta_min=1e-6,
            ),
        ],
        milestones=[warmup_steps],
    )

    # Data loading
    if config.smoke_test or config.dataset_path is None:
        records = generate_needle_dataset(n=20)
        logger.info("Using needle dataset (%d records)", len(records))
    else:
        all_records = load_jsonl(config.dataset_path)
        records, _ = split_by_task_id(all_records)
        logger.info(
            "Loaded %d training records from %s",
            len(records),
            config.dataset_path,
        )

    if not records:
        raise ValueError("No training records loaded; cannot train on empty dataset.")

    # MLflow setup and training loop
    _setup_mlflow(config)

    best_loss = float("inf")
    final_loss = float("inf")
    step_losses: list[float] = []

    with mlflow.start_run(run_name=f"{config.experiment_name}-step{num_steps}"):
        mlflow.log_params(config.model_dump())

        for step in range(1, num_steps + 1):
            record = records[(step - 1) % len(records)]

            loss, metrics = _training_step(
                record=record,
                base_model=base_model,
                tokenizer=tokenizer,
                hypernet=hypernet,
                hc=hc,
                config=config,
            )

            loss.backward()
            clip_grad_norm_(trainable_params, config.grad_clip)
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()

            step_loss = metrics["total_loss"]
            step_losses.append(step_loss)
            final_loss = step_loss
            if step_loss < best_loss:
                best_loss = step_loss

            mlflow.log_metrics(metrics, step=step)
            logger.info(
                "Step %d/%d — loss=%.4f (kl=%.4f, ce=%.4f)",
                step,
                num_steps,
                metrics["total_loss"],
                metrics["kl_loss"],
                metrics["ce_loss"],
            )

            # Tiered checkpointing
            if step % config.full_checkpoint_every == 0:
                ckpt_path = _save_checkpoint(
                    step=step,
                    hypernet=hypernet,
                    optimizer=optimizer,
                    scheduler=scheduler,
                    config=config,
                    hc=hc,
                    best_loss=best_loss,
                    full=True,
                )
                mlflow.log_artifact(str(ckpt_path))
            elif step % config.checkpoint_every == 0:
                ckpt_path = _save_checkpoint(
                    step=step,
                    hypernet=hypernet,
                    optimizer=optimizer,
                    scheduler=scheduler,
                    config=config,
                    hc=hc,
                    best_loss=best_loss,
                    full=False,
                )
                mlflow.log_artifact(str(ckpt_path))

    # Smoke test assertions
    if config.smoke_test:
        for i, sl in enumerate(step_losses):
            assert torch.isfinite(torch.tensor(sl)), (  # noqa: S101
                f"Smoke test: loss at step {i + 1} is not finite: {sl}"
            )
        assert step_losses[-1] < step_losses[0], (  # noqa: S101
            f"Smoke test: final loss {step_losses[-1]:.4f} not less than "
            f"initial loss {step_losses[0]:.4f}"
        )
        assert any(p.grad is not None for p in trainable_params), (  # noqa: S101
            "Smoke test: no trainable param has non-None gradient after training"
        )

    return {
        "final_loss": final_loss,
        "best_loss": best_loss,
        "num_steps_completed": num_steps,
        "checkpoint_dir": config.checkpoint_dir,
    }
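The warmup-then-cosine schedule assembled with SequentialLR above can be sanity-checked without torch. The sketch below (lr_at_step is an illustrative helper, not part of model_training) approximates the shape implied by start_factor=0.01, eta_min=1e-6, and T_max = num_steps - warmup_steps:

```python
import math

def lr_at_step(step: int, base_lr: float, warmup_steps: int, num_steps: int,
               start_factor: float = 0.01, eta_min: float = 1e-6) -> float:
    """Approximate LR under linear warmup followed by cosine annealing."""
    if step < warmup_steps:
        # Linear ramp from start_factor * base_lr up to base_lr.
        frac = step / max(1, warmup_steps)
        return base_lr * (start_factor + (1.0 - start_factor) * frac)
    # Cosine decay from base_lr down to eta_min over the remaining steps.
    t = step - warmup_steps
    t_max = max(1, num_steps - warmup_steps)
    return eta_min + 0.5 * (base_lr - eta_min) * (1 + math.cos(math.pi * t / t_max))

lrs = [lr_at_step(s, 3e-4, warmup_steps=10, num_steps=100) for s in range(100)]
```

The curve rises linearly for the first 10 steps, peaks at the base LR, then decays toward the 1e-6 floor, matching the scheduler composition in the training loop.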

github_client

Thin GitHub REST API client with auth, pagination, and rate-limit retry.

Designed for batch data mining on a training VM (sync httpx is fine).

Classes
GitHubClient
GitHubClient(
    token: str | None = None,
    base_url: str = "https://api.github.com",
)

Minimal GitHub REST API client.

Handles authentication, paginated list endpoints, and automatic retry on rate-limit 403 responses.

Parameters:

- token (str | None): GitHub personal access token. Optional for public endpoints but required for private repos and higher rate limits. Default: None.
- base_url (str): API base URL. Override for GitHub Enterprise. Default: 'https://api.github.com'.

Initialize the client with optional auth token.

Parameters:

- token (str | None): GitHub personal access token. Default: None.
- base_url (str): API base URL. Default: 'https://api.github.com'.
Source code in libs/model-training/src/model_training/github_client.py
def __init__(
    self,
    token: str | None = None,
    base_url: str = "https://api.github.com",
) -> None:
    """Initialize the client with optional auth token.

    Args:
        token: GitHub personal access token.
        base_url: API base URL.
    """
    self._base_url = base_url.rstrip("/")
    self._headers: dict[str, str] = {
        "Accept": "application/vnd.github+json",
        "X-GitHub-Api-Version": "2022-11-28",
    }
    if token is not None:
        self._headers["Authorization"] = f"Bearer {token}"
Functions
get
get(
    path: str,
    params: dict[str, Any] | None = None,
    max_retries: int = 3,
) -> Any

GET a single API endpoint with rate-limit retry.

Parameters:

- path (str): API path relative to base_url (e.g. /repos/owner/repo). Required.
- params (dict[str, Any] | None): Optional query parameters. Default: None.
- max_retries (int): Maximum number of retries on rate-limit 403. Default: 3.

Returns:

- Any: Parsed JSON response body.

Raises:

- HTTPStatusError: On non-rate-limit error responses.

Source code in libs/model-training/src/model_training/github_client.py
def get(
    self,
    path: str,
    params: dict[str, Any] | None = None,
    max_retries: int = 3,
) -> Any:
    """GET a single API endpoint with rate-limit retry.

    Args:
        path: API path relative to base_url (e.g. ``/repos/owner/repo``).
        params: Optional query parameters.
        max_retries: Maximum number of retries on rate-limit 403.

    Returns:
        Parsed JSON response body.

    Raises:
        httpx.HTTPStatusError: On non-rate-limit error responses.
    """
    url = f"{self._base_url}{path}"
    return self._get_response(url, params=params, max_retries=max_retries).json()
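The docstring promises retry on rate-limit 403s, but the _get_response helper that implements it is not shown here. A minimal stdlib sketch of that pattern, where get_with_retry, the dict-shaped responses, and the backoff policy are all illustrative rather than the library's actual API:

```python
import time
from typing import Any, Callable

def get_with_retry(
    fetch: Callable[[], dict[str, Any]],
    max_retries: int = 3,
    sleep: Callable[[float], None] = time.sleep,
) -> dict[str, Any]:
    """Call fetch(), retrying with exponential backoff while it reports a 403."""
    for attempt in range(max_retries + 1):
        resp = fetch()
        if resp["status"] != 403:
            return resp
        if attempt < max_retries:
            sleep(2 ** attempt)  # back off 1s, 2s, 4s, ...
    raise RuntimeError("rate-limit retries exhausted")

# Simulated endpoint: rate-limited twice, then a successful response.
responses = iter([{"status": 403}, {"status": 403}, {"status": 200, "json": []}])
result = get_with_retry(lambda: next(responses), sleep=lambda s: None)
```

Injecting the sleep function keeps the retry path unit-testable without real delays.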
get_paginated
get_paginated(
    path: str,
    params: dict[str, Any] | None = None,
    max_pages: int = 10,
    per_page: int = 100,
) -> list[Any]

GET a paginated list endpoint, following Link rel=next headers.

Parameters:

- path (str): API path relative to base_url. Required.
- params (dict[str, Any] | None): Optional query parameters (per_page is injected). Default: None.
- max_pages (int): Maximum number of pages to fetch. Default: 10.
- per_page (int): Items per page (max 100 for most GitHub endpoints). Default: 100.

Returns:

- list[Any]: Flat list of all items across all fetched pages.

Source code in libs/model-training/src/model_training/github_client.py
def get_paginated(
    self,
    path: str,
    params: dict[str, Any] | None = None,
    max_pages: int = 10,
    per_page: int = 100,
) -> list[Any]:
    """GET a paginated list endpoint, following Link rel=next headers.

    Args:
        path: API path relative to base_url.
        params: Optional query parameters (``per_page`` is injected).
        max_pages: Maximum number of pages to fetch.
        per_page: Items per page (max 100 for most GitHub endpoints).

    Returns:
        Flat list of all items across all fetched pages.
    """
    merged_params: dict[str, Any] = dict(params or {})
    merged_params["per_page"] = per_page

    items: list[Any] = []
    url: str | None = f"{self._base_url}{path}"

    for _ in range(max_pages):
        if url is None:
            break
        resp = self._get_response(url, params=merged_params)
        items.extend(resp.json())

        # After the first request, params are baked into the Link URL.
        merged_params = {}

        link = resp.headers.get("Link", "")
        match = _LINK_NEXT_RE.search(link)
        url = match.group(1) if match else None

    return items
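The pagination loop relies on a compiled pattern, _LINK_NEXT_RE, that extracts the rel="next" URL from GitHub's Link response header; the pattern itself is not shown above. An illustrative stand-in regex with the same behavior:

```python
import re

# Illustrative equivalent of the module's _LINK_NEXT_RE (the real pattern is not shown).
LINK_NEXT_RE = re.compile(r'<([^>]+)>;\s*rel="next"')

link_header = (
    '<https://api.github.com/repositories/1/issues?per_page=100&page=2>; rel="next", '
    '<https://api.github.com/repositories/1/issues?per_page=100&page=5>; rel="last"'
)
match = LINK_NEXT_RE.search(link_header)
next_url = match.group(1) if match else None
```

When the header carries no rel="next" entry (the last page), the search fails and the loop terminates early, exactly as get_paginated's `url = match.group(1) if match else None` does.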

hypernetwork

DocToLoraHypernetwork: Perceiver-based instant LoRA adapter generation.

Generates rank-8 LoRA adapter weights from token IDs in a single forward pass. Distinct from the QLoRA gradient-descent path (Phase 21) — this produces adapters in <1s by cross-attending over token embeddings with learned latents.

IMPORTANT: All GPU imports (torch, safetensors) are deferred inside function/method bodies per INFRA-05 pattern — this module is importable in CPU-only CI.

Usage

from model_training.hypernetwork import (
    DocToLoraHypernetwork,
    save_hypernetwork_adapter,
)

model = DocToLoraHypernetwork(input_dim=DEFAULT_VOCAB_SIZE)
weights = model(token_ids)
save_hypernetwork_adapter(weights, "/tmp/adapter", "Qwen/Qwen2.5-Coder-7B")

Functions
load_pretrained
load_pretrained(
    checkpoint_path: str, device: str = "cpu", **kwargs: Any
) -> Any

Load a pretrained DocToLoraHypernetwork from a checkpoint.

Parameters:

- checkpoint_path (str): Path to the .pt checkpoint file. Required.
- device (str): Device to load onto ('cpu', 'cuda', 'mps'). Default: 'cpu'.
- **kwargs (Any): Override constructor args (input_dim, num_latents, etc.).

Returns:

- Any: DocToLoraHypernetwork nn.Module loaded with pretrained weights.

Source code in libs/model-training/src/model_training/hypernetwork.py
def load_pretrained(
    checkpoint_path: str,
    device: str = "cpu",
    **kwargs: Any,
) -> Any:
    """Load a pretrained DocToLoraHypernetwork from a checkpoint.

    Args:
        checkpoint_path: Path to the .pt checkpoint file.
        device: Device to load onto ('cpu', 'cuda', 'mps'). Default: 'cpu'.
        **kwargs: Override constructor args (input_dim, num_latents, etc.).

    Returns:
        DocToLoraHypernetwork nn.Module loaded with pretrained weights.
    """
    import torch  # noqa: PLC0415

    checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=True)

    # Extract constructor args from checkpoint or use defaults/overrides
    ctor_args = checkpoint.get("hypernetwork_config", {})
    ctor_args.update(kwargs)
    if "input_dim" not in ctor_args:
        ctor_args["input_dim"] = DEFAULT_VOCAB_SIZE

    model = DocToLoraHypernetwork(**ctor_args)
    model.load_state_dict(checkpoint["model_state_dict"])
    model = model.to(device)
    model.eval()
    return model
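The argument resolution in load_pretrained (checkpoint-stored config first, then caller overrides, then an input_dim fallback) is plain dict merging and can be checked without torch. Here resolve_ctor_args is a hypothetical helper mirroring that logic, and the 151_936 value is only a stand-in for the real DEFAULT_VOCAB_SIZE constant:

```python
DEFAULT_VOCAB_SIZE = 151_936  # stand-in value; the real constant lives in hypernetwork.py

def resolve_ctor_args(checkpoint: dict, **kwargs) -> dict:
    """Mirror load_pretrained's resolution: checkpoint config, then overrides, then fallback."""
    ctor_args = dict(checkpoint.get("hypernetwork_config", {}))
    ctor_args.update(kwargs)  # caller kwargs win over stored config
    ctor_args.setdefault("input_dim", DEFAULT_VOCAB_SIZE)
    return ctor_args

args = resolve_ctor_args({"hypernetwork_config": {"num_latents": 64}}, num_latents=32)
```

The explicit num_latents=32 override beats the checkpoint's stored 64, and input_dim falls back to the default because neither source supplied it.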
trajectory_to_tokens
trajectory_to_tokens(
    trajectory_text: str,
    vocab_size: int = DEFAULT_VOCAB_SIZE,
    max_length: int = 2048,
) -> "torch.Tensor"

Encode trajectory text as token IDs for the hypernetwork.

Uses a simple hash-based tokenization (character trigrams mapped to vocab indices). This is intentionally simple — the hypernetwork learns its own embedding, so exact tokenization doesn't matter as long as it's consistent.

Parameters:

- trajectory_text (str): Text to encode (plan, code diffs, test results, etc.). Required.
- vocab_size (int): Size of the hypernetwork's embedding vocabulary. Default: DEFAULT_VOCAB_SIZE.
- max_length (int): Maximum sequence length (truncates or pads). Default: 2048.

Returns:

- torch.Tensor: Token ID tensor of shape (1, max_length) ready for hypernetwork forward().

Source code in libs/model-training/src/model_training/hypernetwork.py
def trajectory_to_tokens(
    trajectory_text: str,
    vocab_size: int = DEFAULT_VOCAB_SIZE,
    max_length: int = 2048,
) -> "torch.Tensor":
    """Encode trajectory text as token IDs for the hypernetwork.

    Uses a simple hash-based tokenization (character trigrams mapped to
    vocab indices). This is intentionally simple — the hypernetwork learns
    its own embedding, so exact tokenization doesn't matter as long as
    it's consistent.

    Args:
        trajectory_text: Text to encode (plan, code diffs, test results, etc.).
        vocab_size: Size of the hypernetwork's embedding vocabulary.
        max_length: Maximum sequence length (truncates or pads).

    Returns:
        Token ID tensor of shape (1, max_length) ready for hypernetwork forward().
    """
    import logging as _logging  # noqa: PLC0415
    import zlib  # noqa: PLC0415

    import torch  # noqa: PLC0415

    min_trajectory_chars = 10
    if len(trajectory_text) < min_trajectory_chars:
        _logging.getLogger(__name__).warning(
            "Trajectory text is very short (%d chars < %d minimum). "
            "Generated adapter may be meaningless.",
            len(trajectory_text),
            min_trajectory_chars,
        )

    tokens: list[int] = []
    for i in range(0, len(trajectory_text) - 2):
        trigram = trajectory_text[i : i + 3]
        token_id = zlib.crc32(trigram.encode()) % vocab_size
        tokens.append(token_id)

    # Pad or truncate to max_length
    if len(tokens) < max_length:
        tokens.extend([0] * (max_length - len(tokens)))
    else:
        tokens = tokens[:max_length]

    return torch.tensor([tokens], dtype=torch.long)
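Because the scheme above only needs zlib.crc32, its central guarantee, that identical text always maps to identical token IDs, is easy to exercise without torch. A standalone re-implementation (illustrative, with small vocab_size and max_length) that returns a plain list instead of a tensor:

```python
import zlib

def text_to_token_ids(text: str, vocab_size: int = 32_000, max_length: int = 16) -> list[int]:
    """Hash character trigrams into [0, vocab_size), then truncate or zero-pad."""
    tokens = [
        zlib.crc32(text[i : i + 3].encode()) % vocab_size
        for i in range(len(text) - 2)
    ]
    tokens = tokens[:max_length]
    return tokens + [0] * (max_length - len(tokens))

a = text_to_token_ids("run the failing test first")
b = text_to_token_ids("run the failing test first")
```

Determinism is the only property the hypernetwork needs from this tokenizer; the learned embedding absorbs the arbitrary ID assignment.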
generate_adapter
generate_adapter(
    hypernetwork: Any,
    trajectory_text: str,
    output_dir: str,
    base_model_id: str,
    vocab_size: int = DEFAULT_VOCAB_SIZE,
    max_length: int = 2048,
    device: str = "cpu",
) -> str

End-to-end: encode trajectory, run hypernetwork, save adapter.

Parameters:

- hypernetwork (Any): A DocToLoraHypernetwork instance. Required.
- trajectory_text (str): Text to encode into the adapter. Required.
- output_dir (str): Directory to save the adapter files. Required.
- base_model_id (str): HuggingFace model ID of the base model. Required.
- vocab_size (int): Vocabulary size for tokenization. Default: DEFAULT_VOCAB_SIZE.
- max_length (int): Max token sequence length. Default: 2048.
- device (str): Device for tensor operations. Default: 'cpu'.

Returns:

- str: Path to the saved adapter directory.

Source code in libs/model-training/src/model_training/hypernetwork.py
def generate_adapter(
    hypernetwork: Any,
    trajectory_text: str,
    output_dir: str,
    base_model_id: str,
    vocab_size: int = DEFAULT_VOCAB_SIZE,
    max_length: int = 2048,
    device: str = "cpu",
) -> str:
    """End-to-end: encode trajectory, run hypernetwork, save adapter.

    Args:
        hypernetwork: A DocToLoraHypernetwork instance.
        trajectory_text: Text to encode into the adapter.
        output_dir: Directory to save the adapter files.
        base_model_id: HuggingFace model ID of the base model.
        vocab_size: Vocabulary size for tokenization.
        max_length: Max token sequence length.
        device: Device for tensor operations.

    Returns:
        Path to the saved adapter directory.
    """
    import torch  # noqa: PLC0415

    # Infer vocab size from hypernetwork's embedding if available
    h_vocab = getattr(
        getattr(hypernetwork, "token_embedding", None), "num_embeddings", None
    )
    effective_vocab = h_vocab if h_vocab is not None else vocab_size
    tokens = trajectory_to_tokens(trajectory_text, effective_vocab, max_length)
    tokens = tokens.to(device)

    with torch.no_grad():
        weights = hypernetwork(tokens)

    rank = getattr(hypernetwork, "rank", 8)
    target_mods = getattr(hypernetwork, "target_modules", ["q_proj", "v_proj"])
    save_hypernetwork_adapter(
        weights,
        output_dir,
        base_model_id,
        rank=rank,
        target_modules=target_mods,
    )
    return output_dir
save_hypernetwork_adapter
save_hypernetwork_adapter(
    weights: dict[str, "torch.Tensor"],
    output_dir: str,
    base_model_id: str,
    rank: int = 8,
    target_modules: list[str] | None = None,
) -> None

Serialize hypernetwork-generated LoRA weights in PEFT adapter format.

Writes:
- adapter_model.safetensors: the LoRA weight tensors
- adapter_config.json: PEFT-compatible configuration

Parameters:

- weights (dict[str, torch.Tensor]): PEFT state_dict from DocToLoraHypernetwork.forward(). Required.
- output_dir (str): Directory to write adapter files to (created if needed). Required.
- base_model_id (str): HuggingFace model ID of the base model. Required.
- rank (int): LoRA rank. Default: 8.
- target_modules (list[str] | None): List of module names. Default: ["q_proj", "v_proj"].

Note:

Does NOT include embed_tokens or lm_head — vLLM rejects these in adapters (per Phase 21-01 decision: no modules_to_save in LoraConfig).

Source code in libs/model-training/src/model_training/hypernetwork.py
def save_hypernetwork_adapter(
    weights: dict[str, "torch.Tensor"],
    output_dir: str,
    base_model_id: str,
    rank: int = 8,
    target_modules: list[str] | None = None,
) -> None:
    """Serialize hypernetwork-generated LoRA weights in PEFT adapter format.

    Writes:
    - adapter_model.safetensors: the LoRA weight tensors
    - adapter_config.json: PEFT-compatible configuration

    Args:
        weights: PEFT state_dict from DocToLoraHypernetwork.forward().
        output_dir: Directory to write adapter files to (created if needed).
        base_model_id: HuggingFace model ID of the base model.
        rank: LoRA rank. Default: 8.
        target_modules: List of module names. Default: ["q_proj", "v_proj"].

    Note:
        Does NOT include embed_tokens or lm_head — vLLM rejects these in adapters
        (per Phase 21-01 decision: no modules_to_save in LoraConfig).
    """
    from safetensors.torch import save_file  # noqa: PLC0415

    if target_modules is None:
        target_modules = ["q_proj", "v_proj"]

    output_path = Path(output_dir)
    output_path.mkdir(parents=True, exist_ok=True)

    # Write adapter weights as safetensors
    safetensors_path = output_path / "adapter_model.safetensors"
    save_file(weights, str(safetensors_path))

    # Write PEFT-compatible adapter_config.json
    adapter_config: dict[str, object] = {
        "peft_type": "LORA",
        "r": rank,
        "lora_alpha": rank * 2,
        "target_modules": target_modules,
        "lora_dropout": 0.0,
        "bias": "none",
        "task_type": "CAUSAL_LM",
        "base_model_name_or_path": base_model_id,
        "inference_mode": True,
        "modules_to_save": None,
        "fan_in_fan_out": False,
    }

    config_path = output_path / "adapter_config.json"
    config_path.write_text(json.dumps(adapter_config, indent=2))
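Since adapter_config.json is plain JSON, the PEFT-compatible layout written above round-trips losslessly. A sketch of that write/read cycle, using a temporary directory in place of output_dir and the example base model from the Usage section:

```python
import json
import tempfile
from pathlib import Path

adapter_config = {
    "peft_type": "LORA",
    "r": 8,
    "lora_alpha": 16,  # rank * 2, as computed in save_hypernetwork_adapter
    "target_modules": ["q_proj", "v_proj"],
    "lora_dropout": 0.0,
    "bias": "none",
    "task_type": "CAUSAL_LM",
    "base_model_name_or_path": "Qwen/Qwen2.5-Coder-7B",
    "inference_mode": True,
    "modules_to_save": None,
    "fan_in_fan_out": False,
}

with tempfile.TemporaryDirectory() as tmp:
    config_path = Path(tmp) / "adapter_config.json"
    config_path.write_text(json.dumps(adapter_config, indent=2))
    loaded = json.loads(config_path.read_text())
```

None serializes as JSON null and back, so the modules_to_save=None decision survives the round trip intact.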

merging

Adapter merging strategies for evolutionary combination.

Implements TIES-Merging and DARE-Merging for combining multiple LoRA adapter state dicts into a single merged adapter. All GPU imports are deferred inside function bodies (INFRA-05 pattern).

Functions
ties_merge
ties_merge(
    state_dicts: list[dict[str, Any]], density: float = 0.5
) -> dict[str, Any]

Merge adapter state dicts using TIES-Merging.

Trim-Elect-Sign-Disjoint merge: for each parameter, trims values below density threshold, elects the majority sign (ignoring trimmed values), then averages only the values matching the elected sign.

Parameters:

- state_dicts (list[dict[str, Any]]): List of state dicts (tensors) to merge. Required.
- density (float): Fraction of values to keep per parameter (0.0 to 1.0). Default: 0.5.

Returns:

- dict[str, Any]: Merged state dict with same keys and shapes as inputs.

Raises:

- ValueError: If state_dicts is empty or density is outside [0.0, 1.0].

Source code in libs/model-training/src/model_training/merging.py
def ties_merge(
    state_dicts: list[dict[str, Any]],
    density: float = 0.5,
) -> dict[str, Any]:
    """Merge adapter state dicts using TIES-Merging.

    Trim-Elect-Sign-Disjoint merge: for each parameter, trims values
    below density threshold, elects the majority sign (ignoring trimmed
    values), then averages only the values matching the elected sign.

    Args:
        state_dicts: List of state dicts (tensors) to merge.
        density: Fraction of values to keep per parameter (0.0 to 1.0).

    Returns:
        Merged state dict with same keys and shapes as inputs.

    Raises:
        ValueError: If *state_dicts* is empty or *density* is outside
            ``[0.0, 1.0]``.
    """
    if not state_dicts:
        raise ValueError(
            "state_dicts must not be empty; TIES-merge of zero inputs is undefined."
        )
    if not (0.0 <= density <= 1.0):
        raise ValueError(f"density must be between 0.0 and 1.0, got {density}")

    import torch

    merged: dict[str, Any] = {}
    keys = state_dicts[0].keys()

    for key in keys:
        original_dtype = state_dicts[0][key].dtype
        tensors = [sd[key].float() for sd in state_dicts]
        stacked = torch.stack(tensors)

        # Track which values survived trimming (True = kept)
        keep_mask = torch.ones_like(stacked, dtype=torch.bool)

        # Trim: zero out values below density threshold per tensor
        for i in range(len(tensors)):
            flat = stacked[i].abs().flatten()
            if flat.numel() == 0:
                continue
            k_keep = max(1, int(flat.numel() * density))
            threshold = torch.topk(flat, k_keep).values[-1]
            mask = stacked[i].abs() >= threshold
            keep_mask[i] = mask
            stacked[i] = stacked[i] * mask.float()

        # Elect sign: majority vote using only surviving (non-trimmed) values
        signs = torch.sign(stacked)
        signs_masked = signs * keep_mask.float()
        sign_sum = signs_masked.sum(dim=0)
        elected_sign = torch.sign(sign_sum)

        # Where elected_sign == 0 (true tie), average all surviving values
        tie_mask = elected_sign == 0

        # Disjoint merge: average only values matching elected sign
        matching = (signs == elected_sign.unsqueeze(0)).float()
        # For tied positions, include all values that survived trimming
        matching = torch.where(
            tie_mask.unsqueeze(0),
            keep_mask.float(),
            matching,
        )
        matching_vals = stacked * matching
        count = matching.sum(dim=0).clamp(min=1)
        merged[key] = (matching_vals.sum(dim=0) / count).to(original_dtype)

    return merged
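The trim, elect-sign, and disjoint-average stages can be illustrated on flat Python lists without torch. This simplified sketch emits 0.0 at positions whose sign vote ties (the torch version above instead averages all surviving values there, which coincides in the symmetric case shown):

```python
def ties_merge_1d(vectors: list[list[float]], density: float = 0.5) -> list[float]:
    """TIES on flat float lists: trim small magnitudes, elect sign, average matches."""
    n = len(vectors[0])
    k_keep = max(1, int(n * density))
    trimmed = []
    for v in vectors:
        # Trim: keep only the top-k magnitudes per vector.
        threshold = sorted((abs(x) for x in v), reverse=True)[k_keep - 1]
        trimmed.append([x if abs(x) >= threshold else 0.0 for x in v])
    merged = []
    for j in range(n):
        col = [v[j] for v in trimmed]
        # Elect: majority sign among surviving (nonzero) values.
        sign_sum = sum((x > 0) - (x < 0) for x in col)
        elected = (sign_sum > 0) - (sign_sum < 0)
        # Disjoint merge: average only survivors matching the elected sign.
        matching = [x for x in col if x != 0.0 and ((x > 0) - (x < 0)) == elected]
        merged.append(sum(matching) / len(matching) if matching else 0.0)
    return merged

out = ties_merge_1d([[1.0, -0.1, 2.0, 0.0], [3.0, -0.2, -2.0, 0.0]], density=0.5)
```

In the example, position 0 agrees in sign and averages to 2.0; position 1 is fully trimmed; position 2 is a sign conflict that cancels out; position 3 was zero everywhere.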
dare_merge
dare_merge(
    state_dicts: list[dict[str, Any]],
    drop_rate: float = 0.1,
    seed: int | None = None,
) -> dict[str, Any]

Merge adapter state dicts using DARE-Merging.

Drop-And-REscale merge: randomly drops a fraction of values from each state dict, then averages the remaining values with rescaling.

Parameters:

- state_dicts (list[dict[str, Any]]): List of state dicts to merge. Required.
- drop_rate (float): Fraction of values to drop per parameter. Must be in [0.0, 1.0). Default: 0.1.
- seed (int | None): Optional RNG seed for reproducible merges. Default: None.

Returns:

- dict[str, Any]: Merged state dict with same keys and shapes as inputs.

Raises:

- ValueError: If state_dicts is empty or drop_rate is outside [0.0, 1.0).

Source code in libs/model-training/src/model_training/merging.py
def dare_merge(
    state_dicts: list[dict[str, Any]],
    drop_rate: float = 0.1,
    seed: int | None = None,
) -> dict[str, Any]:
    """Merge adapter state dicts using DARE-Merging.

    Drop-And-REscale merge: randomly drops a fraction of values from
    each state dict, then averages the remaining values with rescaling.

    Args:
        state_dicts: List of state dicts to merge.
        drop_rate: Fraction of values to drop per parameter. Must be in
            ``[0.0, 1.0)``.
        seed: Optional RNG seed for reproducible merges.

    Returns:
        Merged state dict with same keys and shapes as inputs.

    Raises:
        ValueError: If *state_dicts* is empty or *drop_rate* is outside
            ``[0.0, 1.0)``.
    """
    if not state_dicts:
        raise ValueError(
            "state_dicts must not be empty; DARE-merge of zero inputs is undefined."
        )
    if not (0.0 <= drop_rate < 1.0):
        raise ValueError(f"drop_rate must be in [0.0, 1.0), got {drop_rate}")

    import torch

    if seed is not None:
        torch.manual_seed(seed)

    merged: dict[str, Any] = {}
    keys = state_dicts[0].keys()
    scale = 1.0 / (1.0 - drop_rate)

    for key in keys:
        original_dtype = state_dicts[0][key].dtype
        tensors = [sd[key].float() for sd in state_dicts]
        stacked = torch.stack(tensors)

        # Drop and rescale each tensor
        for i in range(len(tensors)):
            if drop_rate > 0:
                mask = (torch.rand_like(stacked[i]) >= drop_rate).float()
                stacked[i] = stacked[i] * mask * scale

        # Average across state dicts
        merged[key] = stacked.mean(dim=0).to(original_dtype)

    return merged
load_adapter_state_dict
load_adapter_state_dict(
    adapter_path: str | Path,
) -> dict[str, Any]

Load a LoRA adapter state dict from a safetensors file or directory.

Accepts either a direct .safetensors file path or a PEFT adapter directory containing adapter_model.safetensors.

Parameters:

- adapter_path (str | Path): Path to a .safetensors file or a directory containing adapter_model.safetensors. Required.

Returns:

- dict[str, Any]: State dict mapping parameter names to tensors.

Source code in libs/model-training/src/model_training/merging.py
def load_adapter_state_dict(adapter_path: str | Path) -> dict[str, Any]:
    """Load a LoRA adapter state dict from a safetensors file or directory.

    Accepts either a direct ``.safetensors`` file path or a PEFT adapter
    directory containing ``adapter_model.safetensors``.

    Args:
        adapter_path: Path to a ``.safetensors`` file or a directory
            containing ``adapter_model.safetensors``.

    Returns:
        State dict mapping parameter names to tensors.
    """
    from safetensors.torch import load_file

    path = Path(adapter_path)
    if path.is_dir():
        path = path / "adapter_model.safetensors"
    return load_file(str(path), device="cpu")

peft_utils

QLoRA PEFT configuration and adapter management.

All GPU library imports (peft, transformers, torch) are deferred inside function bodies to ensure CPU-only importability (INFRA-05).

Functions
build_qlora_config
build_qlora_config(
    rank: int,
    alpha: int,
    target_modules: list[str],
    dropout: float = 0.1,
) -> Any

Build a QLoRA configuration for PEFT fine-tuning.

Parameters:

- rank (int): LoRA rank (dimensionality of low-rank matrices). Required.
- alpha (int): LoRA alpha scaling factor. Required.
- target_modules (list[str]): List of module names to apply LoRA to. Required.
- dropout (float): Dropout probability for LoRA layers. Default: 0.1.

Returns:

- Any: A peft LoraConfig instance configured for QLoRA.

Example

config = build_qlora_config(rank=64, alpha=128, target_modules=["q_proj"])

Source code in libs/model-training/src/model_training/peft_utils.py
def build_qlora_config(
    rank: int,
    alpha: int,
    target_modules: list[str],
    dropout: float = 0.1,
) -> Any:
    """Build a QLoRA configuration for PEFT fine-tuning.

    Args:
        rank: LoRA rank (dimensionality of low-rank matrices).
        alpha: LoRA alpha scaling factor.
        target_modules: List of module names to apply LoRA to.
        dropout: Dropout probability for LoRA layers.

    Returns:
        A peft LoraConfig instance configured for QLoRA.

    Example:
        >>> config = build_qlora_config(rank=64, alpha=128, target_modules=["q_proj"])
    """
    from peft import LoraConfig  # deferred — GPU/peft not available in CPU CI

    return LoraConfig(
        r=rank,
        lora_alpha=alpha,
        target_modules=target_modules,
        lora_dropout=dropout,
        bias="none",
        task_type="CAUSAL_LM",
    )
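A note on the rank/alpha pair: LoRA scales the low-rank update by alpha / r, so the docstring's example values (rank=64, alpha=128) apply the adapter delta at twice its raw magnitude:

```python
# Effective LoRA scaling is alpha / r; with rank=64 and alpha=128
# the generated delta weights are scaled by 2.0 at merge/forward time.
rank, alpha = 64, 128
scaling = alpha / rank
print(scaling)  # 2.0
```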
apply_lora_adapter
apply_lora_adapter(model: Any, config: Any) -> Any

Apply a LoRA adapter to a base model.

Parameters:

Name Type Description Default
model Any

The base model to wrap with LoRA.

required
config Any

The LoRA configuration (from build_qlora_config).

required

Returns:

Type Description
Any

The model wrapped with a LoRA adapter via peft.get_peft_model.

Example

adapted_model = apply_lora_adapter(base_model, lora_config)

Source code in libs/model-training/src/model_training/peft_utils.py
def apply_lora_adapter(model: Any, config: Any) -> Any:
    """Apply a LoRA adapter to a base model.

    Args:
        model: The base model to wrap with LoRA.
        config: The LoRA configuration (from build_qlora_config).

    Returns:
        The model wrapped with a LoRA adapter via peft.get_peft_model.

    Example:
        >>> adapted_model = apply_lora_adapter(base_model, lora_config)
    """
    from peft import get_peft_model  # deferred — GPU/peft not available in CPU CI

    return get_peft_model(model, config)
merge_adapter
merge_adapter(model: Any) -> Any

Merge LoRA weights into the base model.

Parameters:

Name Type Description Default
model Any

A PEFT model with LoRA adapter applied.

required

Returns:

Type Description
Any

The base model with LoRA weights merged in.

Raises:

Type Description
NotImplementedError

Adapter merging is out of scope for Phase 21.

Example

merged = merge_adapter(peft_model)

Source code in libs/model-training/src/model_training/peft_utils.py
def merge_adapter(model: Any) -> Any:
    """Merge LoRA weights into the base model.

    Args:
        model: A PEFT model with LoRA adapter applied.

    Returns:
        The base model with LoRA weights merged in.

    Raises:
        NotImplementedError: Adapter merging is out of scope for Phase 21.

    Example:
        >>> merged = merge_adapter(peft_model)
    """
    raise NotImplementedError(
        "merge_adapter is not yet implemented. "
        "It will merge LoRA weights into the base model and return the merged model."
    )
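`merge_adapter` is deliberately unimplemented, but the arithmetic a merge would perform is the standard LoRA identity W_merged = W + (alpha / r) · B · A. A pure-Python sketch with toy 2×2 values (illustrative only, not the peft implementation):

```python
# Hypothetical helpers showing the LoRA merge math on plain nested lists.
def matmul(a, b):
    # Naive matrix product: a is (m, k), b is (k, n).
    return [[sum(a[i][k] * b[k][j] for k in range(len(b)))
             for j in range(len(b[0]))] for i in range(len(a))]

def merge_lora(W, A, B, alpha, r):
    # W_merged = W + (alpha / r) * (B @ A)
    scale = alpha / r
    delta = matmul(B, A)
    return [[W[i][j] + scale * delta[i][j] for j in range(len(W[0]))]
            for i in range(len(W))]

W = [[1.0, 0.0], [0.0, 1.0]]   # base weight (2, 2)
A = [[1.0, 2.0]]               # lora_A, shape (r=1, in=2)
B = [[0.5], [0.25]]            # lora_B, shape (out=2, r=1)
print(merge_lora(W, A, B, alpha=2, r=1))  # [[2.0, 2.0], [0.5, 2.0]]
```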

sakana_d2l

SakanaAI Doc-to-LoRA integration.

Wraps Sakana's pretrained HyperLoRA perceiver so it can be used through our hypernetwork interface (load_pretrained → generate_adapter).

The Sakana hypernetwork takes per-layer activations from a base model as input and produces LoRA adapter weights. This module handles:

- Downloading the checkpoint from HuggingFace
- Patching flash-attention assertions for CPU/MPS/non-flash environments
- Extracting per-layer activations from the base model
- Saving the generated LoRA weights in PEFT format

GPU imports are deferred inside function bodies per INFRA-05 pattern.

Functions
download_checkpoint
download_checkpoint(variant: str = DEFAULT_VARIANT) -> Path

Download Sakana's pretrained checkpoint from HuggingFace.

Parameters:

Name Type Description Default
variant str

Which checkpoint variant to download. Options: 'gemma_demo', 'gemma_2b_d2l', 'mistral_7b_d2l', 'qwen_4b_d2l'.

DEFAULT_VARIANT

Returns:

Type Description
Path

Path to the downloaded checkpoint file.

Source code in libs/model-training/src/model_training/sakana_d2l.py
def download_checkpoint(
    variant: str = DEFAULT_VARIANT,
) -> Path:
    """Download Sakana's pretrained checkpoint from HuggingFace.

    Args:
        variant: Which checkpoint variant to download.
            Options: 'gemma_demo', 'gemma_2b_d2l', 'mistral_7b_d2l', 'qwen_4b_d2l'.

    Returns:
        Path to the downloaded checkpoint file.
    """
    # Determine filename based on variant
    if variant == "gemma_demo":
        hf_filename = "gemma_demo/checkpoint-80000/pytorch_model.bin"
    elif variant in ("gemma_2b_d2l", "mistral_7b_d2l", "qwen_4b_d2l"):
        hf_filename = f"{variant}/checkpoint-20000/pytorch_model.bin"
    else:
        msg = f"Unknown variant: {variant}"
        raise ValueError(msg)

    cached = LOCAL_CACHE_DIR / variant / "pytorch_model.bin"
    if cached.exists():
        logger.info("Using cached Sakana checkpoint: %s", cached)
        return cached

    from huggingface_hub import hf_hub_download  # noqa: PLC0415

    logger.info("Downloading Sakana checkpoint %s from %s...", variant, HF_REPO_ID)
    downloaded = Path(hf_hub_download(repo_id=HF_REPO_ID, filename=hf_filename))

    cached.parent.mkdir(parents=True, exist_ok=True)
    import shutil  # noqa: PLC0415

    shutil.copy2(downloaded, cached)
    logger.info("Cached to: %s", cached)
    return cached
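The variant-to-filename mapping is the only branching logic here; mirrored as a standalone helper (name `hf_filename_for` is illustrative):

```python
# Mirrors download_checkpoint's variant dispatch: the demo checkpoint
# lives at step 80000, the production D2L variants at step 20000.
def hf_filename_for(variant: str) -> str:
    if variant == "gemma_demo":
        return "gemma_demo/checkpoint-80000/pytorch_model.bin"
    if variant in ("gemma_2b_d2l", "mistral_7b_d2l", "qwen_4b_d2l"):
        return f"{variant}/checkpoint-20000/pytorch_model.bin"
    raise ValueError(f"Unknown variant: {variant}")

print(hf_filename_for("qwen_4b_d2l"))  # qwen_4b_d2l/checkpoint-20000/pytorch_model.bin
```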
load_sakana_checkpoint
load_sakana_checkpoint(
    checkpoint_path: str | Path | None = None,
    variant: str = DEFAULT_VARIANT,
    device: str = "cpu",
) -> tuple[Any, Any]

Load Sakana's HyperLoRA perceiver from checkpoint.

Downloads from HuggingFace if no local path is provided. Patches flash attention for CPU/MPS compatibility.

Parameters:

Name Type Description Default
checkpoint_path str | Path | None

Path to local checkpoint. If None, downloads from HF.

None
variant str

HF checkpoint variant (only used if checkpoint_path is None).

DEFAULT_VARIANT
device str

Device to load onto.

'cpu'

Returns:

Type Description
tuple[Any, Any]

Tuple of (hypernet, hypernet_config).

Source code in libs/model-training/src/model_training/sakana_d2l.py
def load_sakana_checkpoint(
    checkpoint_path: str | Path | None = None,
    variant: str = DEFAULT_VARIANT,
    device: str = "cpu",
) -> tuple[Any, Any]:
    """Load Sakana's HyperLoRA perceiver from checkpoint.

    Downloads from HuggingFace if no local path is provided.
    Patches flash attention for CPU/MPS compatibility.

    Args:
        checkpoint_path: Path to local checkpoint. If None, downloads from HF.
        variant: HF checkpoint variant (only used if checkpoint_path is None).
        device: Device to load onto.

    Returns:
        Tuple of (hypernet, hypernet_config).
    """
    import torch  # noqa: PLC0415

    _patch_flash_attention()

    # Pre-import flash_attn before torch.load — the unpickler triggers
    # ctx_to_lora module imports in a context that breaks flash_attn
    # resolution if it hasn't been imported yet.
    try:
        import flash_attn.flash_attn_interface  # noqa: F401,PLC0415
    except ImportError:
        pass

    if checkpoint_path is None:
        checkpoint_path = download_checkpoint(variant)

    logger.info("Loading Sakana checkpoint: %s", checkpoint_path)
    sd = torch.load(str(checkpoint_path), map_location="cpu", weights_only=False)

    hc = sd["hypernet_config"]
    logger.info(
        "HypernetConfig: latent_size=%d, lora_r=%d, base_model=%s",
        hc.latent_size,
        hc.lora_config.r,
        sd["base_model_name_or_path"],
    )

    from ctx_to_lora.modeling.hypernet import HyperLoRA  # noqa: PLC0415
    from shared.hardware import resolve_model_dtype  # noqa: PLC0415

    hypernet_param_count = sum(
        v.numel() for v in sd.values() if isinstance(v, torch.Tensor)
    )
    hypernet_dtype = resolve_model_dtype(
        param_count=hypernet_param_count,
        device=device,
    )
    logger.info("HyperLoRA dtype resolved to %s", hypernet_dtype)
    # Suppress "Flash Attention 2 without specifying a torch dtype" warning
    # by setting the dtype on the config before HyperLoRA instantiation.
    if hasattr(hc, "aggregator_config"):
        ac = hc.aggregator_config
        if hasattr(ac, "torch_dtype"):
            ac.torch_dtype = hypernet_dtype
    hypernet = HyperLoRA(hc).to(hypernet_dtype)

    # Load ALL hypernet weights from checkpoint (not just a prefix subset).
    # The checkpoint contains aggregator.*, head.*, scaler_{A,B}.*,
    # bias_{A,B}.*, and layers.0.* — all are required for correct
    # adapter generation.  scaler_B in particular defaults to zeros,
    # so skipping it zeroes out every lora_B matrix.
    model_keys = set(hypernet.state_dict().keys())
    hypernet_sd = {k: v for k, v in sd.items() if k in model_keys}

    loaded = hypernet.load_state_dict(hypernet_sd, strict=False)
    logger.info(
        "Loaded %d/%d hypernet weight tensors",
        len(hypernet_sd),
        len(model_keys),
    )
    if loaded.missing_keys:
        logger.warning(
            "Missing keys (%d, will use defaults): %s",
            len(loaded.missing_keys),
            loaded.missing_keys,
        )
    if loaded.unexpected_keys:
        logger.info("Unexpected keys: %d", len(loaded.unexpected_keys))

    hypernet = hypernet.to(device)
    hypernet.eval()

    param_count = sum(p.numel() for p in hypernet.parameters())
    logger.info("HyperLoRA params: %d", param_count)

    return hypernet, hc
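The key-filtering step before `load_state_dict(strict=False)` matters: the checkpoint dict also carries non-tensor entries (config, base model name), so only names present in the model's state dict are kept. A stdlib sketch with made-up key names:

```python
# Illustrative key names only — the real checkpoint has aggregator.*,
# head.*, scaler_{A,B}.*, bias_{A,B}.*, and layers.0.* tensors.
model_keys = {"aggregator.w", "head.w", "scaler_B.w"}
checkpoint = {
    "aggregator.w": 1,
    "head.w": 2,
    "scaler_B.w": 3,
    "hypernet_config": object(),          # non-tensor metadata, must be dropped
    "base_model_name_or_path": "Qwen/...",
}
filtered = {k: v for k, v in checkpoint.items() if k in model_keys}
missing = sorted(model_keys - filtered.keys())
print(len(filtered), missing)  # 3 []
```

Skipping any of these tensor groups silently degrades the adapter (e.g. an unloaded `scaler_B` zeroes every lora_B matrix), which is why the filter keeps everything the model can accept rather than a prefix subset.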
transfer_aggregator_weights
transfer_aggregator_weights(
    hypernet: Any, checkpoint_path: str | Path
) -> Any

Load aggregator weights from a Sakana checkpoint into a HyperLoRA instance.

Loads only aggregator.* weights from the checkpoint (not head.*), freezes all aggregator parameters (requires_grad=False), and leaves head.* at PyTorch default initialization for Phase 29 training against the new target model.

This enables reuse of the pretrained Perceiver aggregator across different target model architectures. The aggregator maps document embeddings to LoRA weight space and is model-agnostic; only the head needs retraining per target model.

Parameters:

Name Type Description Default
hypernet Any

The HyperLoRA model to load weights into (mutated in-place).

required
checkpoint_path str | Path

Path to the Sakana checkpoint (.bin file).

required

Returns:

Type Description
Any

The mutated hypernet (returned for chaining convenience).

Source code in libs/model-training/src/model_training/sakana_d2l.py
def transfer_aggregator_weights(hypernet: Any, checkpoint_path: str | Path) -> Any:
    """Load aggregator weights from a Sakana checkpoint into a HyperLoRA instance.

    Loads only aggregator.* weights from the checkpoint (not head.*), freezes all
    aggregator parameters (requires_grad=False), and leaves head.* at PyTorch default
    initialization for Phase 29 training against the new target model.

    This enables reuse of the pretrained Perceiver aggregator across different target
    model architectures. The aggregator maps document embeddings to LoRA weight space
    and is model-agnostic; only the head needs retraining per target model.

    Args:
        hypernet: The HyperLoRA model to load weights into (mutated in-place).
        checkpoint_path: Path to the Sakana checkpoint (.bin file).

    Returns:
        The mutated hypernet (returned for chaining convenience).
    """
    import torch  # noqa: PLC0415

    sd = torch.load(str(checkpoint_path), map_location="cpu", weights_only=False)

    # Filter to only aggregator.* tensors that exist in the target model
    model_sd = hypernet.state_dict()
    aggregator_sd = {
        k: v
        for k, v in sd.items()
        if k.startswith("aggregator.") and isinstance(v, torch.Tensor) and k in model_sd
    }

    logger.info(
        "Loading %d aggregator weights from checkpoint: %s",
        len(aggregator_sd),
        checkpoint_path,
    )

    loaded = hypernet.load_state_dict(aggregator_sd, strict=False)
    _assert_transfer_integrity(hypernet, loaded)

    # Freeze all aggregator parameters — only head will be trained
    frozen_count = 0
    trainable_count = 0
    for name, param in hypernet.named_parameters():
        if name.startswith("aggregator."):
            param.requires_grad_(False)
            frozen_count += 1
        else:
            trainable_count += 1

    logger.info(
        "Froze %d aggregator params; %d params (head.*) remain trainable",
        frozen_count,
        trainable_count,
    )

    return hypernet
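The freeze bookkeeping reduces to a prefix partition over parameter names (names below are illustrative, not the real module paths):

```python
# Split parameters by the "aggregator." prefix, as the freeze loop does:
# aggregator params get requires_grad=False, everything else stays trainable.
params = [
    "aggregator.layers.0.weight",
    "aggregator.norm.weight",
    "head.proj.weight",
]
frozen = [n for n in params if n.startswith("aggregator.")]
trainable = [n for n in params if not n.startswith("aggregator.")]
print(len(frozen), len(trainable))  # 2 1
```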
get_aggregator_config
get_aggregator_config(checkpoint_path: str | Path) -> Any

Extract the Perceiver aggregator structural config from a Sakana checkpoint.

Reads the aggregator_config from the checkpoint's HypernetConfig so that d2l_config.py can populate the aggregator_config=None placeholder set in Phase 25.

Parameters:

Name Type Description Default
checkpoint_path str | Path

Path to the Sakana checkpoint (.bin file).

required

Returns:

Type Description
Any

The aggregator_config object from the checkpoint's HypernetConfig.

Raises:

Type Description
ValueError

If the checkpoint's aggregator_config is None (predates this field).

Source code in libs/model-training/src/model_training/sakana_d2l.py
def get_aggregator_config(checkpoint_path: str | Path) -> Any:
    """Extract the Perceiver aggregator structural config from a Sakana checkpoint.

    Reads the aggregator_config from the checkpoint's HypernetConfig so that
    d2l_config.py can populate the aggregator_config=None placeholder set in Phase 25.

    Args:
        checkpoint_path: Path to the Sakana checkpoint (.bin file).

    Returns:
        The aggregator_config object from the checkpoint's HypernetConfig.

    Raises:
        ValueError: If the checkpoint's aggregator_config is None (predates this field).
    """
    import torch  # noqa: PLC0415

    sd = torch.load(str(checkpoint_path), map_location="cpu", weights_only=False)
    hc = sd["hypernet_config"]

    if hc.aggregator_config is None:
        msg = (
            "aggregator_config is None in checkpoint — "
            "checkpoint may predate this field"
        )
        raise ValueError(msg)

    return hc.aggregator_config
extract_activations
extract_activations(
    text: str,
    base_model_name: str,
    layer_indices: list[int],
    device: str = "cpu",
    max_length: int = 512,
) -> tuple[Any, Any]

Extract per-layer hidden state activations from the base model.

Backward-compatible wrapper around extract_activations_with_model(). Loads model and tokenizer, delegates extraction, then cleans up.

Parameters:

Name Type Description Default
text str

Input text to process.

required
base_model_name str

HuggingFace model ID for the base model.

required
layer_indices list[int]

Which layers to extract activations from.

required
device str

Device for computation.

'cpu'
max_length int

Max token sequence length.

512

Returns:

Type Description
tuple[Any, Any]

Tuple of (features, attention_mask) ready for HyperLoRA. features shape: (1, num_layers, seq_len, hidden_dim); attention_mask shape: (1, seq_len).

Source code in libs/model-training/src/model_training/sakana_d2l.py
def extract_activations(
    text: str,
    base_model_name: str,
    layer_indices: list[int],
    device: str = "cpu",
    max_length: int = 512,
) -> tuple[Any, Any]:
    """Extract per-layer hidden state activations from the base model.

    Backward-compatible wrapper around extract_activations_with_model().
    Loads model and tokenizer, delegates extraction, then cleans up.

    Args:
        text: Input text to process.
        base_model_name: HuggingFace model ID for the base model.
        layer_indices: Which layers to extract activations from.
        device: Device for computation.
        max_length: Max token sequence length.

    Returns:
        Tuple of (features, attention_mask) ready for HyperLoRA.
        features shape: (1, num_layers, seq_len, hidden_dim)
        attention_mask shape: (1, seq_len)
    """
    if not text or not text.strip():
        raise ValueError(
            "extract_activations called with empty text; adapter would be meaningless."
        )

    import torch  # noqa: PLC0415
    from shared.hardware import resolve_model_dtype  # noqa: PLC0415
    from transformers import AutoModelForCausalLM, AutoTokenizer  # noqa: PLC0415

    from model_training.d2l_probe import extract_activations_with_model  # noqa: PLC0415

    logger.info("Loading base model %s for activation extraction...", base_model_name)
    tokenizer = AutoTokenizer.from_pretrained(base_model_name)

    # Estimate param count from model config to resolve dtype
    from transformers import AutoConfig  # noqa: PLC0415

    config = AutoConfig.from_pretrained(base_model_name)
    estimated_params = getattr(config, "num_parameters", None)
    if estimated_params is None:
        # Rough estimate: vocab_size * hidden + num_layers * 4 * hidden^2
        vocab = getattr(config, "vocab_size", 256000)
        hidden = getattr(config, "hidden_size", 2304)
        n_layers = getattr(config, "num_hidden_layers", 26)
        estimated_params = vocab * hidden + n_layers * 4 * hidden * hidden

    # Account for inference model already on GPU as overhead
    overhead = 0
    if device != "cpu" and torch.cuda.is_available():
        overhead = torch.cuda.memory_allocated(0)

    activation_dtype = resolve_model_dtype(
        param_count=estimated_params,
        device=device,
        overhead_bytes=overhead,
    )
    logger.info("Activation extraction dtype resolved to %s", activation_dtype)

    model: Any = AutoModelForCausalLM.from_pretrained(
        base_model_name,
        dtype=activation_dtype,
    )
    model = model.to(device)  # type: ignore[assignment]
    model.eval()

    result = extract_activations_with_model(
        text=text,
        model=model,
        tokenizer=tokenizer,
        layer_indices=layer_indices,
        max_length=max_length,
    )

    del model
    if device != "cpu":
        torch.cuda.empty_cache()
    return result
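The fallback parameter estimate used for dtype resolution is worth seeing with the default config values (the defaults shown match the `getattr` fallbacks in the code above):

```python
# Rough estimate: embedding table + 4 * hidden^2 per transformer layer.
# With the fallback defaults (vocab=256000, hidden=2304, 26 layers) this
# lands near 1.14B parameters, i.e. a ~1B-class model.
vocab, hidden, n_layers = 256000, 2304, 26
estimated = vocab * hidden + n_layers * 4 * hidden * hidden
print(estimated)  # 1141899264
```

The estimate only needs to be order-of-magnitude correct, since it is consumed by `resolve_model_dtype` to decide whether the model fits in a wider dtype.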
generate_adapter_from_sakana
generate_adapter_from_sakana(
    text: str,
    output_dir: str,
    checkpoint_path: str | Path | None = None,
    variant: str = DEFAULT_VARIANT,
    base_model_name: str | None = None,
    device: str = "cpu",
    max_length: int = 512,
    scaling_factor: float = 0.16,
) -> str

End-to-end: text → base model activations → HyperLoRA → PEFT adapter.

This is the main entry point. It:

1. Loads the Sakana pretrained perceiver (downloading if needed)
2. Runs text through the base model to get per-layer activations
3. Feeds activations through the perceiver to generate LoRA weights
4. Saves weights in PEFT-compatible format

Parameters:

Name Type Description Default
text str

Input text (trajectory, document, context) to encode.

required
output_dir str

Directory to save the PEFT adapter files.

required
checkpoint_path str | Path | None

Path to local checkpoint, or None to download.

None
variant str

HF checkpoint variant if downloading.

DEFAULT_VARIANT
base_model_name str | None

Override base model. If None, uses the one from checkpoint.

None
device str

Device for computation.

'cpu'
max_length int

Maximum token sequence length for activation extraction.

512
scaling_factor float

Adapter scaling multiplier (0-1, default from config).

0.16

Returns:

Type Description
str

Path to the saved adapter directory.

Source code in libs/model-training/src/model_training/sakana_d2l.py
def generate_adapter_from_sakana(
    text: str,
    output_dir: str,
    checkpoint_path: str | Path | None = None,
    variant: str = DEFAULT_VARIANT,
    base_model_name: str | None = None,
    device: str = "cpu",
    max_length: int = 512,
    scaling_factor: float = 0.16,
) -> str:
    """End-to-end: text → base model activations → HyperLoRA → PEFT adapter.

    This is the main entry point. It:
      1. Loads the Sakana pretrained perceiver (downloading if needed)
      2. Runs text through the base model to get per-layer activations
      3. Feeds activations through the perceiver to generate LoRA weights
      4. Saves weights in PEFT-compatible format

    Args:
        text: Input text (trajectory, document, context) to encode.
        output_dir: Directory to save the PEFT adapter files.
        checkpoint_path: Path to local checkpoint, or None to download.
        variant: HF checkpoint variant if downloading.
        base_model_name: Override base model. If None, uses the one from checkpoint.
        device: Device for computation.
        max_length: Maximum token sequence length for activation extraction.
        scaling_factor: Adapter scaling multiplier (0-1, default from config).

    Returns:
        Path to the saved adapter directory.
    """
    import torch  # noqa: PLC0415

    hypernet, hc = load_sakana_checkpoint(checkpoint_path, variant, device)

    # Determine base model from checkpoint config
    if base_model_name is None:
        # Load checkpoint just to read base_model_name_or_path
        if checkpoint_path is None:
            checkpoint_path = download_checkpoint(variant)
        sd = torch.load(str(checkpoint_path), map_location="cpu", weights_only=False)
        base_model_name = sd["base_model_name_or_path"]
        del sd

    logger.info("Base model: %s", base_model_name)

    # Extract activations from base model
    layer_indices = list(hc.layer_indices)
    features, attn_mask = extract_activations(
        text=text,
        base_model_name=base_model_name,
        layer_indices=layer_indices,
        device=device,
        max_length=max_length,
    )

    # Generate LoRA weights via perceiver
    logger.info("Generating LoRA weights via HyperLoRA perceiver...")
    with torch.no_grad():
        lora_dict, layernorm_dict = hypernet.generate_weights(features, attn_mask, None)

    # Combine generated weights with bias — Sakana concatenates bias as extra
    # rank dimensions (rank 8 → 16 for single chunk). This is how the
    # checkpoint was trained and evaluated.
    from ctx_to_lora.modeling.lora_merger import combine_lora as _combine_lora

    n_chunks = torch.ones(1, dtype=torch.int32)
    lora_bias = hypernet.get_head_bias() if hypernet.config.use_bias else None
    lora_dict = _combine_lora(lora_dict, n_chunks, lora_bias=lora_bias)

    # Save as PEFT adapter
    _save_sakana_adapter(
        lora_dict=lora_dict,
        output_dir=output_dir,
        base_model_name=base_model_name,
        hc=hc,
        scaling_factor=scaling_factor,
    )

    # Free hypernet VRAM — it's not needed after adapter weights are saved
    del hypernet, lora_dict, layernorm_dict, features, attn_mask
    if device != "cpu":
        torch.cuda.empty_cache()

    return output_dir

trainer

QLoRA training orchestrator.

All GPU-dependent imports (datasets, transformers, trl, torch) are deferred inside function bodies to ensure CPU-only importability (INFRA-05).

Module-level imports: stdlib only.

Functions
train_qlora
train_qlora(
    session_id: str,
    adapter_id: str,
    output_dir: str,
    *,
    base_model_id: str | None = None,
    task_type: str = "code-gen",
    rank: int = 64,
    alpha: int = 128,
    epochs: int = 3,
    learning_rate: float = 0.0002,
) -> str

Train a QLoRA adapter from a recorded coding trajectory.

Orchestrates the full training pipeline: load trajectory, format as SFT messages, build dataset, load model with NF4 quantization, train with SFT, and save the adapter to output_dir.

All GPU imports are deferred to this function body; the module is safe to import in CPU-only environments.

Parameters:

Name Type Description Default
session_id str

Trajectory session ID to train from.

required
adapter_id str

Unique identifier for the resulting adapter.

required
output_dir str

Directory to save the trained adapter weights.

required
base_model_id str | None

HuggingFace model ID. Defaults to RUNE_BASE_MODEL env var or "Qwen/Qwen2.5-Coder-7B-Instruct".

None
task_type str

Task category (e.g. 'code-gen', 'bug-fix').

'code-gen'
rank int

LoRA rank.

64
alpha int

LoRA alpha scaling factor.

128
epochs int

Number of training epochs.

3
learning_rate float

Optimizer learning rate.

0.0002

Returns:

Type Description
str

output_dir path where the adapter was saved.

Raises:

Type Description
FileNotFoundError

If the trajectory file does not exist.

ValueError

If the trajectory is not successful or has no SFT messages.

Source code in libs/model-training/src/model_training/trainer.py
def train_qlora(
    session_id: str,
    adapter_id: str,
    output_dir: str,
    *,
    base_model_id: str | None = None,
    task_type: str = "code-gen",
    rank: int = 64,
    alpha: int = 128,
    epochs: int = 3,
    learning_rate: float = 2e-4,
) -> str:
    """Train a QLoRA adapter from a recorded coding trajectory.

    Orchestrates the full training pipeline: load trajectory, format as SFT
    messages, build dataset, load model with NF4 quantization, train with SFT,
    and save the adapter to output_dir.

    All GPU imports are deferred to this function body; the module is safe to
    import in CPU-only environments.

    Args:
        session_id: Trajectory session ID to train from.
        adapter_id: Unique identifier for the resulting adapter.
        output_dir: Directory to save the trained adapter weights.
        base_model_id: HuggingFace model ID. Defaults to RUNE_BASE_MODEL env var
            or "Qwen/Qwen2.5-Coder-7B-Instruct".
        task_type: Task category (e.g. 'code-gen', 'bug-fix').
        rank: LoRA rank.
        alpha: LoRA alpha scaling factor.
        epochs: Number of training epochs.
        learning_rate: Optimizer learning rate.

    Returns:
        output_dir path where the adapter was saved.

    Raises:
        FileNotFoundError: If the trajectory file does not exist.
        ValueError: If the trajectory is not successful or has no SFT messages.
    """
    # Deferred GPU imports — all in function body for CPU-only importability
    import torch
    from datasets import Dataset
    from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
    from trl import SFTConfig, SFTTrainer

    from model_training.peft_utils import build_qlora_config
    from model_training.trajectory import format_for_sft, load_trajectory

    # Resolve model ID — read env var inside function body for monkeypatch testability
    model_id = base_model_id or os.environ.get(
        "RUNE_BASE_MODEL", "Qwen/Qwen2.5-Coder-7B-Instruct"
    )

    # Load and format trajectory
    trajectory = load_trajectory(session_id)  # raises FileNotFoundError if missing
    messages = format_for_sft(trajectory)
    if not messages:
        raise ValueError(
            f"Trajectory {session_id} is not successful or has no SFT messages"
        )

    # Build dataset
    dataset = Dataset.from_list([{"messages": messages}])

    # NF4 quantization config (bfloat16 compute dtype prevents silent NaN loss)
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
        bnb_4bit_use_double_quant=True,
    )

    # Load model and tokenizer
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        quantization_config=bnb_config,
        device_map="auto",
    )
    tokenizer = AutoTokenizer.from_pretrained(model_id)

    # Build LoRA config (bfloat16 is set in BitsAndBytesConfig, not LoraConfig)
    lora_config = build_qlora_config(
        rank=rank,
        alpha=alpha,
        target_modules=["q_proj", "v_proj"],
        dropout=0.1,
    )

    # Training arguments
    training_args = SFTConfig(
        output_dir=output_dir,
        num_train_epochs=epochs,
        learning_rate=learning_rate,
        warmup_ratio=0.03,
        lr_scheduler_type="cosine",
        bf16=True,
        per_device_train_batch_size=1,
        gradient_accumulation_steps=1,
        save_strategy="no",
        logging_steps=1,
        report_to="none",
        eval_strategy="no",
    )

    # Create and run trainer
    trainer = SFTTrainer(
        model=model,
        args=training_args,
        train_dataset=dataset,
        peft_config=lora_config,
        processing_class=tokenizer,
    )
    trainer.train()

    # Save adapter weights only (PEFT-aware: safetensors + adapter_config.json)
    trainer.save_model(output_dir)

    return output_dir
train_and_register
train_and_register(
    session_id: str,
    adapter_id: str,
    *,
    base_model_id: str | None = None,
    task_type: str = "code-gen",
    rank: int = 64,
    alpha: int = 128,
    epochs: int = 3,
    learning_rate: float = 0.0002,
    database_url: str | None = None,
) -> str

Train a QLoRA adapter and register it in the AdapterRegistry.

Combines train_qlora() with AdapterRegistry.store() to produce a fully registered adapter ready for vLLM serving.

Parameters:

Name Type Description Default
session_id str

Trajectory session ID to train from.

required
adapter_id str

Unique identifier for the resulting adapter.

required
base_model_id str | None

HuggingFace model ID. Defaults to RUNE_BASE_MODEL env var.

None
task_type str

Task category (e.g. 'code-gen', 'bug-fix').

'code-gen'
rank int

LoRA rank.

64
alpha int

LoRA alpha scaling factor.

128
epochs int

Number of training epochs.

3
learning_rate float

Optimizer learning rate.

0.0002
database_url str | None

SQLAlchemy database URL. Defaults to RUNE_DATABASE_URL env var or "sqlite:///{home}/.rune/rune.db".

None

Returns:

Type Description
str

adapter_id of the registered adapter.

Raises:

Type Description
FileNotFoundError

If the trajectory file does not exist.

ValueError

If the trajectory is not successful or has no SFT messages.

Source code in libs/model-training/src/model_training/trainer.py
def train_and_register(
    session_id: str,
    adapter_id: str,
    *,
    base_model_id: str | None = None,
    task_type: str = "code-gen",
    rank: int = 64,
    alpha: int = 128,
    epochs: int = 3,
    learning_rate: float = 2e-4,
    database_url: str | None = None,
) -> str:
    """Train a QLoRA adapter and register it in the AdapterRegistry.

    Combines train_qlora() with AdapterRegistry.store() to produce a fully
    registered adapter ready for vLLM serving.

    Args:
        session_id: Trajectory session ID to train from.
        adapter_id: Unique identifier for the resulting adapter.
        base_model_id: HuggingFace model ID. Defaults to RUNE_BASE_MODEL env var.
        task_type: Task category (e.g. 'code-gen', 'bug-fix').
        rank: LoRA rank.
        alpha: LoRA alpha scaling factor.
        epochs: Number of training epochs.
        learning_rate: Optimizer learning rate.
        database_url: SQLAlchemy database URL. Defaults to RUNE_DATABASE_URL env var
            or "sqlite:///{home}/.rune/rune.db".

    Returns:
        adapter_id of the registered adapter.

    Raises:
        FileNotFoundError: If the trajectory file does not exist.
        ValueError: If the trajectory is not successful or has no SFT messages.
    """
    from adapter_registry.models import AdapterRecord
    from adapter_registry.registry import AdapterRegistry
    from sqlalchemy import create_engine

    # Resolve adapter output dir — read env var inside function body
    adapter_base = os.environ.get("RUNE_ADAPTER_DIR")
    if adapter_base:
        adapter_dir = Path(adapter_base) / adapter_id
    else:
        adapter_dir = Path.home() / ".rune" / "adapters" / adapter_id
    adapter_dir.mkdir(parents=True, exist_ok=True)

    # Resolve model ID inside function body for testability
    model_id = base_model_id or os.environ.get(
        "RUNE_BASE_MODEL", "Qwen/Qwen2.5-Coder-7B-Instruct"
    )

    output_dir = str(adapter_dir)
    train_qlora(
        session_id=session_id,
        adapter_id=adapter_id,
        output_dir=output_dir,
        base_model_id=model_id,
        task_type=task_type,
        rank=rank,
        alpha=alpha,
        epochs=epochs,
        learning_rate=learning_rate,
    )

    # Compute file hash and size from the saved safetensors file
    adapter_file = adapter_dir / "adapter_model.safetensors"
    file_hash = hashlib.sha256(adapter_file.read_bytes()).hexdigest()
    file_size_bytes = adapter_file.stat().st_size

    # Build the adapter record
    record = AdapterRecord(
        id=adapter_id,
        version=1,
        task_type=task_type,
        base_model_id=model_id,
        rank=rank,
        created_at=datetime.now(tz=timezone.utc).isoformat(),
        file_path=output_dir,
        file_hash=file_hash,
        file_size_bytes=file_size_bytes,
        source="qlora",
        session_id=session_id,
    )

    # Resolve DB URL inside function body for testability
    db_url = database_url or os.environ.get(
        "RUNE_DATABASE_URL",
        f"sqlite:///{Path.home() / '.rune' / 'rune.db'}",
    )
    engine = create_engine(db_url)
    registry = AdapterRegistry(engine=engine)
    registry.store(record)

    return adapter_id
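
Before registration, the saved adapter is fingerprinted with a SHA-256 content hash and its size on disk, exactly as in the source above. The bookkeeping can be sketched in isolation (a temporary file stands in for adapter_model.safetensors):

```python
import hashlib
import tempfile
from pathlib import Path

# Stand-in for the saved adapter weights file.
adapter_file = Path(tempfile.mkdtemp()) / "adapter_model.safetensors"
adapter_file.write_bytes(b"fake-adapter-weights")

# Same fingerprinting as train_and_register: content hash + size on disk.
file_hash = hashlib.sha256(adapter_file.read_bytes()).hexdigest()
file_size_bytes = adapter_file.stat().st_size

assert len(file_hash) == 64  # hex-encoded SHA-256
assert file_size_bytes == len(b"fake-adapter-weights")
```

Hashing the file contents (rather than trusting a path) lets the registry detect a corrupted or swapped adapter file before it is served.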

trajectory

Trajectory recording and formatting for coding session distillation.

Provides functions to persist, load, and convert coding session trajectories into SFT-compatible chat format for LoRA fine-tuning pipelines.

Functions
record_trajectory
record_trajectory(
    session_id: str,
    steps: list[dict[str, Any]],
    outcome: Optional[str] = None,
    *,
    task_description: str = "",
    task_type: str = "",
    adapter_ids: list[str] | None = None,
) -> dict[str, Any]

Persist a coding session trajectory to disk for future distillation.

Parameters:

- session_id (str, required): Unique identifier for the coding session.
- steps (list[dict[str, Any]], required): List of step dicts, each containing attempt results.
- outcome (Optional[str], default None): Final session result ('success', 'exhausted', or None).
- task_description (str, default ''): Natural language description of the coding task.
- task_type (str, default ''): Category of task (e.g. 'function', 'class', 'refactor').
- adapter_ids (list[str] | None, default None): LoRA adapter IDs used during the session.

Returns:

- dict[str, Any]: A dict with 'session_id' and 'file_path' keys.

Source code in libs/model-training/src/model_training/trajectory.py
def record_trajectory(
    session_id: str,
    steps: list[dict[str, Any]],
    outcome: Optional[str] = None,
    *,
    task_description: str = "",
    task_type: str = "",
    adapter_ids: list[str] | None = None,
) -> dict[str, Any]:
    """Persist a coding session trajectory to disk for future distillation.

    Args:
        session_id: Unique identifier for the coding session.
        steps: List of step dicts, each containing attempt results.
        outcome: Final session result ('success', 'exhausted', or None).
        task_description: Natural language description of the coding task.
        task_type: Category of task (e.g. 'function', 'class', 'refactor').
        adapter_ids: LoRA adapter IDs used during the session.

    Returns:
        A dict with 'session_id' and 'file_path' keys.
    """
    trajectory_dir = _get_trajectory_dir()
    trajectory_dir.mkdir(parents=True, exist_ok=True)

    file_path = trajectory_dir / f"{session_id}.json"
    timestamp = datetime.now(tz=timezone.utc).isoformat()

    trajectory: dict[str, Any] = {
        "session_id": session_id,
        "task_description": task_description,
        "task_type": task_type,
        "adapter_ids": adapter_ids if adapter_ids is not None else [],
        "outcome": outcome,
        "timestamp": timestamp,
        "steps": steps,
    }

    file_path.write_text(json.dumps(trajectory, indent=2))

    return {"session_id": session_id, "file_path": str(file_path)}
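
The on-disk format written by record_trajectory is plain JSON with the keys shown in the source above. A standalone round trip with the same schema (writing to a temporary directory rather than _get_trajectory_dir(); the task content is made up for illustration):

```python
import json
import tempfile
from datetime import datetime, timezone
from pathlib import Path

trajectory_dir = Path(tempfile.mkdtemp())
session_id = "sess-001"

# Same schema that record_trajectory persists.
trajectory = {
    "session_id": session_id,
    "task_description": "Write a function that reverses a string.",
    "task_type": "function",
    "adapter_ids": [],
    "outcome": "success",
    "timestamp": datetime.now(tz=timezone.utc).isoformat(),
    "steps": [{"generated_code": "def rev(s): return s[::-1]", "tests_passed": True}],
}

file_path = trajectory_dir / f"{session_id}.json"
file_path.write_text(json.dumps(trajectory, indent=2))

# Reading back mirrors load_trajectory(): look the file up by session ID.
loaded = json.loads(file_path.read_text())
assert loaded["outcome"] == "success"
```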
load_trajectory
load_trajectory(trajectory_id: str) -> dict[str, Any]

Load a stored trajectory by session ID.

Parameters:

- trajectory_id (str, required): The session ID used as the filename (without .json).

Returns:

- dict[str, Any]: A dict containing the full trajectory data, including steps and metadata.

Raises:

- FileNotFoundError: If no trajectory file exists for the given ID.

Source code in libs/model-training/src/model_training/trajectory.py
def load_trajectory(trajectory_id: str) -> dict[str, Any]:
    """Load a stored trajectory by session ID.

    Args:
        trajectory_id: The session ID used as the filename (without .json).

    Returns:
        A dict containing the full trajectory data including steps and metadata.

    Raises:
        FileNotFoundError: If no trajectory file exists for the given ID.
    """
    trajectory_dir = _get_trajectory_dir()
    file_path = trajectory_dir / f"{trajectory_id}.json"
    # Let FileNotFoundError propagate naturally if file does not exist
    return json.loads(file_path.read_text())  # type: ignore[no-any-return]
format_for_sft
format_for_sft(
    trajectory: dict[str, Any],
) -> list[dict[str, str]]

Convert a trajectory into SFT-compatible chat format.

Only successful trajectories (outcome == 'success') produce output. Extracts the final step where tests_passed is True as the assistant message.

Parameters:

- trajectory (dict[str, Any], required): A trajectory dict as returned by load_trajectory.

Returns:

- list[dict[str, str]]: A list of three message dicts ([system, user, assistant]) for successful trajectories, or an empty list if the trajectory did not succeed.

Source code in libs/model-training/src/model_training/trajectory.py
def format_for_sft(trajectory: dict[str, Any]) -> list[dict[str, str]]:
    """Convert a trajectory into SFT-compatible chat format.

    Only successful trajectories (outcome == 'success') produce output.
    Extracts the final step where tests_passed is True as the assistant message.

    Args:
        trajectory: A trajectory dict as returned by load_trajectory.

    Returns:
        A list of 3 message dicts ([system, user, assistant]) for successful
        trajectories, or an empty list if the trajectory did not succeed.
    """
    if trajectory.get("outcome") != "success":
        return []

    steps: list[dict[str, Any]] = trajectory.get("steps", [])
    successful_step = next(
        (s for s in reversed(steps) if s.get("tests_passed")),
        None,
    )

    if successful_step is None:
        return []

    task_description: str = trajectory.get("task_description", "")
    generated_code: str = successful_step.get("generated_code", "")

    return [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": task_description},
        {"role": "assistant", "content": generated_code},
    ]
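
The selection rules can be exercised on a minimal trajectory. The sketch below mirrors the documented logic (an inline SYSTEM_PROMPT stands in for the module constant, and the sample steps are invented): only a 'success' outcome produces output, and the last step whose tests passed supplies the assistant turn.

```python
from typing import Any

SYSTEM_PROMPT = "You are a coding assistant."  # stand-in for the module constant

def format_for_sft(trajectory: dict[str, Any]) -> list[dict[str, str]]:
    # Only successful trajectories produce output; the last step with
    # tests_passed=True supplies the assistant message.
    if trajectory.get("outcome") != "success":
        return []
    step = next(
        (s for s in reversed(trajectory.get("steps", [])) if s.get("tests_passed")),
        None,
    )
    if step is None:
        return []
    return [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": trajectory.get("task_description", "")},
        {"role": "assistant", "content": step.get("generated_code", "")},
    ]

messages = format_for_sft({
    "outcome": "success",
    "task_description": "Reverse a string.",
    "steps": [
        {"generated_code": "draft", "tests_passed": False},
        {"generated_code": "def rev(s): return s[::-1]", "tests_passed": True},
    ],
})
assert [m["role"] for m in messages] == ["system", "user", "assistant"]
```

Note that the earlier failing draft is skipped: searching steps in reverse keeps only the code that actually passed the tests, which is what makes the resulting chat triple usable as an SFT example.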