| """ |
| neural_data.py — Training data manager for MLX LoRA fine-tuning. |
| |
| Manages a rolling buffer of recent conversation turns and a persistent |
| replay buffer for anti-catastrophic-forgetting experience replay. |
| """ |
|
|
| import json |
| import random |
| import time |
| from collections import deque |
| from pathlib import Path |
| from typing import Optional |
|
|
|
|
class TrainingExample:
    """One conversation turn, stored as chat-style messages plus metadata.

    Instances use ``__slots__`` (no per-instance ``__dict__``) because many
    of them may be held in the rolling and replay buffers at once.
    """

    __slots__ = ("messages", "timestamp", "token_count", "session_id")

    def __init__(self, messages: list[dict], timestamp: float = 0,
                 token_count: int = 0, session_id: str = ""):
        """Create an example.

        Args:
            messages: List of message dicts for the turn.
            timestamp: Creation time; any falsy value (including the default
                ``0``) is replaced with the current ``time.time()``.
            token_count: Approximate size of the response.
            session_id: Identifier of the originating session, if any.
        """
        if not timestamp:
            timestamp = time.time()
        self.messages = messages
        self.timestamp = timestamp
        self.token_count = token_count
        self.session_id = session_id

    def to_dict(self) -> dict:
        """Return a JSON-serializable dict with one key per slot, in slot order."""
        return {field: getattr(self, field) for field in TrainingExample.__slots__}

    @classmethod
    def from_dict(cls, d: dict) -> "TrainingExample":
        """Build an example from a ``to_dict``-style mapping.

        ``"messages"`` is required (``KeyError`` if absent); the remaining
        fields fall back to the constructor defaults.
        """
        messages = d["messages"]
        return cls(
            messages,
            timestamp=d.get("timestamp", 0),
            token_count=d.get("token_count", 0),
            session_id=d.get("session_id", ""),
        )
|
|
|
|
class TrainingDataManager:
    """Manages rolling buffer + persistent replay for LoRA training.

    * The rolling buffer is a bounded deque holding the ``rolling_size`` most
      recently accepted examples.
    * The replay buffer holds up to ``replay_size`` examples chosen by
      reservoir sampling over every accepted example. It is loaded from
      ``replay_path`` (JSONL, one example per line) on construction when that
      path is set, and written back by ``save_replay``.
    """

    def __init__(self, rolling_size: int = 100, replay_size: int = 500,
                 replay_path: str = "", min_response_tokens: int = 10):
        self.rolling_size = rolling_size
        self.replay_size = replay_size
        self.min_response_tokens = min_response_tokens
        self.replay_path = replay_path

        self._rolling: deque[TrainingExample] = deque(maxlen=rolling_size)
        self._replay: list[TrainingExample] = []
        # Turns accepted by add_turn since construction / the last clear().
        self._total_added = 0
        # Examples the replay reservoir has been offered, INCLUDING those
        # loaded from disk. Kept separate from _total_added so that reservoir
        # odds stay correct after a reload; using _total_added made every new
        # turn unconditionally evict persisted history.
        self._replay_seen = 0

        if replay_path:
            self._load_replay()

    @property
    def rolling_count(self) -> int:
        return len(self._rolling)

    @property
    def replay_count(self) -> int:
        return len(self._replay)

    @property
    def total_added(self) -> int:
        return self._total_added

    def add_turn(self, user_text: str, assistant_text: str,
                 system_prompt: str = "", session_id: str = "") -> bool:
        """Add a conversation turn to the training buffers.

        The turn is rejected when the assistant reply is blank or has fewer
        than ``min_response_tokens`` whitespace-separated words (a rough
        token estimate).

        Returns:
            True if the example was accepted (not filtered).
        """
        approx_tokens = len(assistant_text.split())
        if approx_tokens < self.min_response_tokens:
            return False
        # Only matters when min_response_tokens <= 0, where blank replies
        # would otherwise pass the length check above.
        if not assistant_text.strip():
            return False

        messages = []
        if system_prompt:
            messages.append({"role": "system", "content": system_prompt})
        messages.append({"role": "user", "content": user_text})
        messages.append({"role": "assistant", "content": assistant_text})

        example = TrainingExample(
            messages=messages,
            token_count=approx_tokens,
            session_id=session_id,
        )

        self._rolling.append(example)
        self._total_added += 1
        self._offer_to_replay(example)
        return True

    def _offer_to_replay(self, example: "TrainingExample") -> None:
        """Reservoir-sample ``example`` into the replay buffer (Algorithm R)."""
        self._replay_seen += 1
        if len(self._replay) < self.replay_size:
            self._replay.append(example)
            return
        # The n-th offered example replaces a random slot with probability
        # replay_size / n.
        idx = random.randint(0, self._replay_seen - 1)
        if idx < self.replay_size:
            self._replay[idx] = example

    def get_training_batch(self, batch_size: int = 1,
                           replay_ratio: float = 0.3) -> list[TrainingExample]:
        """Get a training batch mixing recent and replay examples.

        Args:
            batch_size: Total examples in batch. 0 (or negative) = all
                available data (rolling plus replay, without duplicates).
            replay_ratio: Fraction of batch drawn from the replay buffer;
                clamped to [0.0, 1.0].

        Returns:
            Shuffled list of distinct TrainingExample objects. It holds at
            most ``batch_size`` items; fewer only if the buffers together
            hold fewer distinct examples. If the replay buffer cannot supply
            its share, the remainder is filled with older rolling examples.
            Returns [] when the rolling buffer is empty.
        """
        if not self._rolling:
            return []

        if batch_size <= 0:
            batch = list(self._rolling)
            # Every accepted turn is in both buffers; skip replay entries
            # that are the very same objects as rolling entries.
            rolling_ids = {id(ex) for ex in self._rolling}
            batch.extend(ex for ex in self._replay if id(ex) not in rolling_ids)
            random.shuffle(batch)
            return batch

        replay_ratio = min(max(replay_ratio, 0.0), 1.0)
        n_replay = int(batch_size * replay_ratio)
        n_recent = batch_size - n_replay

        recent = list(self._rolling)
        # Guard n_recent == 0: recent[-0:] would be the whole list.
        batch = recent[-n_recent:] if n_recent > 0 else []

        if n_replay > 0:
            used = {id(ex) for ex in batch}
            pool = [ex for ex in self._replay if id(ex) not in used]
            batch.extend(random.sample(pool, min(n_replay, len(pool))))

        # Top up with the newest not-yet-used rolling examples so the batch
        # reaches batch_size whenever enough distinct data exists.
        shortfall = batch_size - len(batch)
        if shortfall > 0:
            used = {id(ex) for ex in batch}
            leftover = [ex for ex in recent if id(ex) not in used]
            batch.extend(leftover[-shortfall:])

        random.shuffle(batch)
        return batch

    def get_recent(self, n: int = 5) -> list[TrainingExample]:
        """Get the N most recent training examples (oldest first); [] if n <= 0."""
        if n <= 0:
            # Without this guard, list[-0:] would return the whole buffer.
            return []
        return list(self._rolling)[-n:]

    def _default_rolling_path(self) -> str:
        """Default rolling-buffer file: buffer.jsonl beside the replay file."""
        return str(Path(self.replay_path).parent / "buffer.jsonl")

    @staticmethod
    def _write_jsonl(path: str, examples) -> None:
        """Write ``examples`` to ``path`` as JSONL, replacing it atomically.

        Data goes to a sibling ``.tmp`` file that is then renamed over the
        target, so a crash mid-write cannot truncate an existing file.
        """
        target = Path(path)
        target.parent.mkdir(parents=True, exist_ok=True)
        tmp = target.with_name(target.name + ".tmp")
        try:
            with open(tmp, "w", encoding="utf-8") as f:
                f.writelines(json.dumps(ex.to_dict()) + "\n" for ex in examples)
            tmp.replace(target)
        except BaseException:
            tmp.unlink(missing_ok=True)
            raise

    @staticmethod
    def _read_jsonl(path: str) -> list[TrainingExample]:
        """Parse a JSONL file of examples, skipping blank lines."""
        examples = []
        with open(path, encoding="utf-8") as f:
            for raw in f:
                line = raw.strip()
                if line:
                    examples.append(TrainingExample.from_dict(json.loads(line)))
        return examples

    def save_rolling(self, path: str = ""):
        """Save rolling buffer to disk (default: buffer.jsonl next to replay_path)."""
        self._write_jsonl(path or self._default_rolling_path(), self._rolling)

    def load_rolling(self, path: str = ""):
        """Load rolling buffer from disk, replacing its contents.

        Does nothing if the file does not exist. The file is fully parsed
        before the buffer is touched, so a parse error leaves it unchanged.
        Only the last ``rolling_size`` entries are kept (deque ``maxlen``).
        """
        path = path or self._default_rolling_path()
        if not Path(path).exists():
            return
        examples = self._read_jsonl(path)
        self._rolling.clear()
        self._rolling.extend(examples)

    def save_replay(self):
        """Persist replay buffer to ``replay_path`` (no-op if unset)."""
        if not self.replay_path:
            return
        self._write_jsonl(self.replay_path, self._replay)

    def _load_replay(self):
        """Load replay buffer from disk, down-sampling to ``replay_size``."""
        if not self.replay_path or not Path(self.replay_path).exists():
            return
        examples = self._read_jsonl(self.replay_path)
        # Count everything on disk as already offered to the reservoir so new
        # turns compete with it at proper reservoir odds. This is a lower
        # bound: the true historical stream length is not persisted.
        self._replay_seen = len(examples)
        if len(examples) > self.replay_size:
            examples = random.sample(examples, self.replay_size)
        self._replay = examples

    def clear(self):
        """Clear all buffers and counters (for reset)."""
        self._rolling.clear()
        self._replay.clear()
        self._total_added = 0
        self._replay_seen = 0

    def stats(self) -> dict:
        """Return buffer statistics."""
        return {
            "rolling_count": self.rolling_count,
            "rolling_capacity": self.rolling_size,
            "replay_count": self.replay_count,
            "replay_capacity": self.replay_size,
            "total_added": self._total_added,
        }
|
|