"""
Helion-2.5-Rnd Advanced Data Loader
Efficient data loading and preprocessing for inference
"""

import hashlib
import json
import logging
import shutil
from pathlib import Path
from typing import Any, Dict, Iterator, List, Optional

import numpy as np
from safetensors.torch import load_file

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class SafeTensorsLoader:
    """Efficient SafeTensors model loading with validation"""

    def __init__(self, model_path: str, device: str = "cuda"):
        """
        Initialize SafeTensors loader

        Args:
            model_path: Path to model directory
            device: Target device for loading
        """
        self.model_path = Path(model_path)
        self.device = device
        self.index = self._load_index()
        self.loaded_shards = {}

    def _load_index(self) -> Dict:
        """Load SafeTensors index file"""
        index_path = self.model_path / "model.safetensors.index.json"

        if not index_path.exists():
            raise FileNotFoundError(f"Index file not found: {index_path}")

        with open(index_path, 'r') as f:
            index = json.load(f)

        logger.info(f"Loaded index with {len(index.get('weight_map', {}))} weight mappings")
        return index

    def get_shard_path(self, shard_name: str) -> Path:
        """Get full path to shard file"""
        return self.model_path / shard_name

    def load_shard(self, shard_name: str, lazy: bool = False) -> Dict:
        """
        Load a single SafeTensors shard

        Args:
            shard_name: Name of shard file
            lazy: If True, skip the in-memory shard cache so the tensors
                can be garbage-collected after use

        Returns:
            Dictionary of tensors
        """
        if shard_name in self.loaded_shards:
            logger.debug(f"Using cached shard: {shard_name}")
            return self.loaded_shards[shard_name]

        shard_path = self.get_shard_path(shard_name)

        if not shard_path.exists():
            raise FileNotFoundError(f"Shard not found: {shard_path}")

        logger.info(f"Loading shard: {shard_name}")

        try:
            tensors = load_file(str(shard_path), device=self.device)

            if not lazy:
                self.loaded_shards[shard_name] = tensors

            return tensors

        except Exception as e:
            logger.error(f"Failed to load shard {shard_name}: {e}")
            raise

    def load_weight(self, weight_name: str) -> Any:
        """
        Load a specific weight by name

        Args:
            weight_name: Name of the weight tensor

        Returns:
            Weight tensor
        """
        weight_map = self.index.get('weight_map', {})

        if weight_name not in weight_map:
            raise KeyError(f"Weight not found in index: {weight_name}")

        shard_name = weight_map[weight_name]
        tensors = self.load_shard(shard_name)

        return tensors[weight_name]

    def load_all_weights(self, progress_callback=None) -> Dict:
        """
        Load all model weights

        Args:
            progress_callback: Optional callable invoked as
                progress_callback(shards_done, total_shards)

        Returns:
            Dictionary of all weights
        """
        all_weights = {}
        weight_map = self.index.get('weight_map', {})
        unique_shards = set(weight_map.values())

        logger.info(f"Loading {len(unique_shards)} shards...")

        for i, shard_name in enumerate(sorted(unique_shards)):
            tensors = self.load_shard(shard_name)
            all_weights.update(tensors)

            if progress_callback:
                progress_callback(i + 1, len(unique_shards))

        logger.info(f"Loaded {len(all_weights)} weight tensors")
        return all_weights

    def validate_checksums(self) -> Dict[str, Optional[bool]]:
        """
        Validate SHA256 checksums of all shards

        Returns:
            Dictionary mapping shard names to validation status
            (True/False, or None if no checksum is recorded)
        """
        results = {}
        file_metadata = self.index.get('file_metadata', {})

        for shard_name, metadata in file_metadata.items():
            expected_hash = metadata.get('sha256')

            if not expected_hash:
                results[shard_name] = None
                continue

            shard_path = self.get_shard_path(shard_name)

            if not shard_path.exists():
                results[shard_name] = False
                continue

            # Hash in 1 MiB chunks to avoid loading multi-GB shards into memory
            sha256 = hashlib.sha256()
            with open(shard_path, 'rb') as f:
                for chunk in iter(lambda: f.read(1024 * 1024), b''):
                    sha256.update(chunk)

            actual_hash = sha256.hexdigest()
            results[shard_name] = (actual_hash == expected_hash)

            status = "✓" if results[shard_name] else "✗"
            logger.info(f"{status} {shard_name}")

        return results

    def get_model_info(self) -> Dict:
        """Get model information from index"""
        metadata = self.index.get('metadata', {})

        return {
            'model_name': metadata.get('model_name', 'Unknown'),
            'version': metadata.get('version', 'Unknown'),
            'total_size_bytes': metadata.get('total_size', 0),
            'total_size_gb': metadata.get('total_size', 0) / (1024**3),
            'format': metadata.get('format', 'safetensors'),
            'precision': metadata.get('precision', 'unknown'),
            'total_shards': metadata.get('total_shards', 0),
            'parameters': metadata.get('parameters', 'Unknown')
        }

    def clear_cache(self):
        """Clear loaded shard cache"""
        self.loaded_shards.clear()
        logger.info("Cleared shard cache")
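

# Minimal usage sketch for selective loading. The directory and the weight
# name below are illustrative placeholders, not guaranteed checkpoint
# contents:
#
#     loader = SafeTensorsLoader("./models/helion", device="cpu")
#     embeddings = loader.load_weight("model.embed_tokens.weight")
#     weights = loader.load_all_weights(
#         progress_callback=lambda done, total: logger.info(f"{done}/{total} shards")
#     )
#     loader.clear_cache()

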
class DatasetPreprocessor:
    """Preprocess datasets for inference"""

    def __init__(self, tokenizer=None, max_length: int = 131072):
        """
        Initialize preprocessor

        Args:
            tokenizer: Tokenizer instance
            max_length: Maximum sequence length
        """
        self.tokenizer = tokenizer
        self.max_length = max_length

    def preprocess_text(self, text: str) -> str:
        """
        Preprocess raw text

        Args:
            text: Input text

        Returns:
            Preprocessed text
        """
        # Collapse all runs of whitespace (including newlines) to single spaces
        text = ' '.join(text.split())

        # Drop non-printable control characters; the newline/tab allowance is
        # defensive, since the collapse above already removed them
        text = ''.join(char for char in text if ord(char) >= 32 or char in '\n\t')

        return text.strip()

    def preprocess_chat_messages(self, messages: List[Dict[str, str]]) -> str:
        """
        Preprocess chat messages into a ChatML-style prompt

        Args:
            messages: List of message dictionaries with 'role' and 'content'

        Returns:
            Formatted prompt string ending with an open assistant turn
        """
        formatted = ""

        for msg in messages:
            role = msg.get('role', 'user')
            content = self.preprocess_text(msg.get('content', ''))
            formatted += f"<|im_start|>{role}\n{content}<|im_end|>\n"

        formatted += "<|im_start|>assistant\n"
        return formatted

    def batch_preprocess(
        self,
        texts: List[str],
        add_special_tokens: bool = True,
        padding: bool = True,
        truncation: bool = True
    ) -> Dict:
        """
        Batch preprocess texts

        Args:
            texts: List of input texts
            add_special_tokens: Whether to add special tokens
            padding: Whether to pad sequences
            truncation: Whether to truncate sequences

        Returns:
            Batch of preprocessed data
        """
        if self.tokenizer is None:
            raise ValueError("Tokenizer not initialized")

        processed_texts = [self.preprocess_text(text) for text in texts]

        encodings = self.tokenizer(
            processed_texts,
            add_special_tokens=add_special_tokens,
            padding=padding,
            truncation=truncation,
            max_length=self.max_length,
            return_tensors='pt'
        )

        return encodings

    def stream_process_file(
        self,
        file_path: str,
        batch_size: int = 32
    ) -> Iterator[Dict]:
        """
        Stream process large files in batches

        Args:
            file_path: Path to input file (.jsonl or .txt)
            batch_size: Number of samples per batch

        Yields:
            Batches of preprocessed data
        """
        path = Path(file_path)

        if path.suffix == '.jsonl':
            with open(path, 'r', encoding='utf-8') as f:
                batch = []

                for line in f:
                    try:
                        data = json.loads(line)
                        batch.append(data.get('text', ''))

                        if len(batch) >= batch_size:
                            yield self.batch_preprocess(batch)
                            batch = []

                    except json.JSONDecodeError:
                        logger.warning("Skipping invalid JSON line")

                if batch:
                    yield self.batch_preprocess(batch)

        elif path.suffix == '.txt':
            with open(path, 'r', encoding='utf-8') as f:
                batch = []

                for line in f:
                    # Skip blank lines rather than feeding empty texts downstream
                    line = line.strip()
                    if not line:
                        continue
                    batch.append(line)

                    if len(batch) >= batch_size:
                        yield self.batch_preprocess(batch)
                        batch = []

                if batch:
                    yield self.batch_preprocess(batch)

        else:
            raise ValueError(f"Unsupported file format: {path.suffix}")
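

# Minimal usage sketch, assuming a Hugging Face tokenizer; AutoTokenizer and
# the paths below are illustrative assumptions, not pinned dependencies:
#
#     from transformers import AutoTokenizer
#     tokenizer = AutoTokenizer.from_pretrained("./models/helion")
#     prep = DatasetPreprocessor(tokenizer=tokenizer, max_length=131072)
#
#     prompt = prep.preprocess_chat_messages([
#         {"role": "system", "content": "You are a helpful assistant."},
#         {"role": "user", "content": "Hello!"},
#     ])
#     for batch in prep.stream_process_file("data.jsonl", batch_size=32):
#         ...  # run inference on each tokenized batch

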
class InferenceDataCollator:
    """Collate data for efficient batch inference"""

    def __init__(self, pad_token_id: int = 128001):
        """
        Initialize data collator

        Args:
            pad_token_id: ID for padding token
        """
        self.pad_token_id = pad_token_id

    def __call__(self, features: List[Dict]) -> Dict:
        """
        Collate features into batch

        Args:
            features: List of feature dictionaries

        Returns:
            Batched features
        """
        if not features:
            return {}

        # Pad every sequence to the longest one in the batch
        max_length = max(len(f['input_ids']) for f in features)

        batch = {
            'input_ids': [],
            'attention_mask': []
        }

        for feature in features:
            input_ids = feature['input_ids']
            attention_mask = feature.get('attention_mask', [1] * len(input_ids))

            padding_length = max_length - len(input_ids)

            input_ids = input_ids + [self.pad_token_id] * padding_length
            attention_mask = attention_mask + [0] * padding_length

            batch['input_ids'].append(input_ids)
            batch['attention_mask'].append(attention_mask)

        batch['input_ids'] = np.array(batch['input_ids'], dtype=np.int64)
        batch['attention_mask'] = np.array(batch['attention_mask'], dtype=np.int64)

        return batch

    def dynamic_padding(self, features: List[Dict], padding_multiple: int = 8) -> Dict:
        """
        Apply dynamic padding optimized for hardware

        Args:
            features: List of feature dictionaries
            padding_multiple: Pad to a multiple of this value (e.g. 8 for
                tensor-core-friendly shapes)

        Returns:
            Batched features with optimal padding
        """
        if not features:
            return {}

        max_length = max(len(f['input_ids']) for f in features)

        # Round the batch length up to the nearest multiple
        padded_length = ((max_length + padding_multiple - 1) // padding_multiple) * padding_multiple

        batch = {
            'input_ids': [],
            'attention_mask': []
        }

        for feature in features:
            input_ids = feature['input_ids']
            attention_mask = feature.get('attention_mask', [1] * len(input_ids))

            padding_length = padded_length - len(input_ids)

            input_ids = input_ids + [self.pad_token_id] * padding_length
            attention_mask = attention_mask + [0] * padding_length

            batch['input_ids'].append(input_ids)
            batch['attention_mask'].append(attention_mask)

        batch['input_ids'] = np.array(batch['input_ids'], dtype=np.int64)
        batch['attention_mask'] = np.array(batch['attention_mask'], dtype=np.int64)

        return batch
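

# Minimal usage sketch with toy token IDs (the values are illustrative):
#
#     collator = InferenceDataCollator(pad_token_id=128001)
#     batch = collator.dynamic_padding(
#         [{"input_ids": [1, 2, 3]}, {"input_ids": [4, 5]}],
#         padding_multiple=8,
#     )
#     # Both rows come back padded to length 8; batch["attention_mask"]
#     # marks which positions hold real tokens.

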
class CachedDataLoader:
    """Data loader with caching for repeated inference"""

    def __init__(self, cache_dir: str = "./cache"):
        """
        Initialize cached data loader

        Args:
            cache_dir: Directory for cache storage
        """
        self.cache_dir = Path(cache_dir)
        self.cache_dir.mkdir(parents=True, exist_ok=True)

    def get_cache_key(self, text: str) -> str:
        """Generate cache key from text"""
        return hashlib.sha256(text.encode()).hexdigest()

    def load_from_cache(self, cache_key: str) -> Optional[Any]:
        """
        Load data from cache

        Args:
            cache_key: Cache identifier

        Returns:
            Cached data or None
        """
        cache_path = self.cache_dir / f"{cache_key}.json"

        if not cache_path.exists():
            return None

        try:
            with open(cache_path, 'r') as f:
                return json.load(f)
        except Exception as e:
            logger.warning(f"Failed to load from cache: {e}")
            return None

    def save_to_cache(self, cache_key: str, data: Any):
        """
        Save data to cache

        Args:
            cache_key: Cache identifier
            data: Data to cache (must be JSON-serializable)
        """
        cache_path = self.cache_dir / f"{cache_key}.json"

        try:
            with open(cache_path, 'w') as f:
                json.dump(data, f)
        except Exception as e:
            logger.warning(f"Failed to save to cache: {e}")

    def clear_cache(self):
        """Clear all cached data"""
        shutil.rmtree(self.cache_dir)
        self.cache_dir.mkdir(parents=True, exist_ok=True)
        logger.info("Cache cleared")


def main():
    """Example usage"""
    loader = SafeTensorsLoader("./models/helion")

    # Inspect model metadata recorded in the index
    info = loader.get_model_info()
    print(f"Model: {info['model_name']}")
    print(f"Size: {info['total_size_gb']:.2f} GB")
    print(f"Shards: {info['total_shards']}")

    # Verify shard integrity against the recorded SHA256 hashes
    print("\nValidating checksums...")
    results = loader.validate_checksums()
    valid_count = sum(1 for v in results.values() if v)
    print(f"Valid: {valid_count}/{len(results)}")
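
    # Hedged example: round-trip a value through the JSON cache. The sample
    # text and payload here are illustrative only.
    cache = CachedDataLoader("./cache")
    key = cache.get_cache_key("example prompt")
    cache.save_to_cache(key, {"tokens": [1, 2, 3]})
    print(f"Cache hit: {cache.load_from_cache(key) is not None}")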


if __name__ == "__main__":
    main()