"""Flask server that runs a Hugging Face causal LM on client-supplied embeddings.

``/`` renders a form used to load a model into this process; ``/generate``
runs a forward pass on embeddings posted by a client and returns the final
hidden state.
"""

import json
import os
import traceback

import torch
from flask import Flask, flash, jsonify, render_template, request
from huggingface_hub import hf_hub_download, login
from transformers import AutoModelForCausalLM, AutoTokenizer

app = Flask(__name__)
# Random per-process key: flashed messages do not survive a restart.
app.secret_key = os.urandom(24)

# Process-wide model state; populated by a successful POST to "/".
ee_model = None
ee_tokenizer = None
ee_config = None
ee_model_name = None

SPACE_HOST = os.environ.get("SPACE_HOST", "")
SPACE_URL = f"https://{SPACE_HOST}" if SPACE_HOST else "http://localhost:7860"


@app.route("/", methods=["GET", "POST"])
def index():
    """Render the control page and, on POST, load the requested model.

    The POST form must carry ``ee_model_name`` (a Hub repo id) and
    ``hf_token``. The model, its tokenizer and the repo's ``ee_config.json``
    are loaded into local variables first and published to the module
    globals only when every step succeeded, so a failed load never leaves
    the server half-configured or advertising the wrong model name.

    Returns:
        The rendered ``index.html`` template.
    """
    global ee_model, ee_tokenizer, ee_config, ee_model_name

    if request.method == "POST":
        requested_name = request.form["ee_model_name"].strip()
        hf_token = request.form["hf_token"].strip()
        try:
            login(token=hf_token)
            # NOTE(review): trust_remote_code=True executes Python shipped in
            # the requested repo, and this route is unauthenticated. Confirm
            # that only trusted operators can reach this form.
            model = AutoModelForCausalLM.from_pretrained(
                requested_name,
                torch_dtype=torch.float16,
                device_map="auto",
                trust_remote_code=True,
            )
            tokenizer = AutoTokenizer.from_pretrained(
                requested_name,
                trust_remote_code=True,
            )
            config_path = hf_hub_download(requested_name, "ee_config.json")
            with open(config_path, encoding="utf-8") as f:
                config = json.load(f)
        except Exception as e:
            # Top-level boundary: report any load failure to the operator and
            # keep whatever model was previously loaded untouched.
            flash(f"Error: {e}", "danger")
        else:
            # Commit all pieces together so the globals are always consistent.
            ee_model = model
            ee_tokenizer = tokenizer
            ee_config = config
            ee_model_name = requested_name
            flash(f"✅ Model loaded: {requested_name}", "success")
            flash("Point your Client Space to this Space's URL below.", "info")

    return render_template(
        "index.html",
        server_ready=(ee_model is not None),
        model_name=ee_model_name if ee_config else None,
        space_url=SPACE_URL,
    )


@app.route("/generate", methods=["POST"])
def generate():
    """Run one forward pass on client-supplied embeddings.

    Expects a JSON object with:
      * ``inputs_embeds`` (required): nested list of input embeddings,
        presumably shaped (batch, seq_len, hidden) -- confirm with the client.
      * ``attention_mask`` (optional): nested list; defaults to all ones over
        the first two dimensions of ``inputs_embeds``.
      * ``past_key_values`` (optional): accepted for compatibility but
        ignored. The model is always called with ``use_cache=False`` and no
        KV cache is returned, so every request is a full recomputation.

    The response contains only ``last_hidden`` (the final entry of the
    model's ``hidden_states``); logits are never returned and token IDs are
    never received. Per the original design notes the embeddings are
    sigma-encrypted and the client applies sigma_inv + lm_head itself; that
    transformation is not visible in this file.

    Returns:
        ``{"last_hidden": ...}`` on success; ``{"error": ...}`` with status
        400 when no model is loaded or the body is not usable JSON;
        ``{"error": ..., "traceback": ...}`` with status 500 on any other
        failure.
    """
    if ee_model is None:
        return jsonify({"error": "Server not started yet"}), 400

    # silent=True turns a missing or malformed JSON body into None instead of
    # an exception, so bad requests are answered with 400 rather than 500.
    data = request.get_json(silent=True)
    if not isinstance(data, dict) or "inputs_embeds" not in data:
        return jsonify(
            {"error": "Request body must be a JSON object with 'inputs_embeds'"}
        ), 400

    try:
        model_dtype = next(ee_model.parameters()).dtype
        device = ee_model.device

        inputs_embeds = torch.tensor(data["inputs_embeds"]).to(
            dtype=model_dtype, device=device
        )

        mask_payload = data.get("attention_mask")
        if mask_payload is None:
            # Attend to every position; sized from the input so batches
            # larger than one also get a matching mask.
            attention_mask = torch.ones(
                inputs_embeds.shape[:2], dtype=torch.long, device=device
            )
        else:
            attention_mask = torch.tensor(mask_payload).to(device=device)

        with torch.no_grad():
            out = ee_model(
                inputs_embeds=inputs_embeds,
                attention_mask=attention_mask,
                use_cache=False,
                output_hidden_states=True,
            )

        # Only the last layer's hidden state leaves the server.
        last_hidden = out.hidden_states[-1]
        return jsonify({"last_hidden": last_hidden.cpu().tolist()})
    except Exception as e:
        app.logger.exception("Forward pass failed in /generate")
        # NOTE(review): the traceback is still returned to the caller to keep
        # the existing response shape; it exposes server internals and should
        # be dropped once clients no longer rely on it.
        return jsonify({"error": str(e), "traceback": traceback.format_exc()}), 500


if __name__ == "__main__":
    app.run(host="0.0.0.0", port=7860)