| | from flask import Flask, render_template, request, flash, jsonify |
| | import torch |
| | from transformers import AutoModelForCausalLM, AutoTokenizer |
| | from huggingface_hub import login |
| | import os, json |
| |
|
# The Flask application; the route decorators further down register on it.
app = Flask(__name__)
# Random key generated once per process at import time. Flask needs a secret
# key for session-backed features such as flash(); because the key is
# regenerated on every start, existing sessions do not survive a restart.
app.config["SECRET_KEY"] = os.urandom(24)
| |
|
# Process-wide state for the currently loaded model. All four names start
# out as None and are rebound (via `global`) by the "/" route handler.
ee_model = ee_tokenizer = ee_config = ee_model_name = None

# Public base URL rendered on the setup page. SPACE_HOST comes from the
# environment; when it is unset or empty, fall back to a local address.
SPACE_HOST = os.environ.get("SPACE_HOST", "")
if SPACE_HOST:
    SPACE_URL = "https://" + SPACE_HOST
else:
    SPACE_URL = "http://localhost:7860"
| |
|
| |
|
def _load_ee_model(model_name, hf_token):
    """Load the model, tokenizer and ``ee_config.json`` for ``model_name``.

    Logs in to the Hugging Face Hub first, but only when ``hf_token`` is
    non-empty. Returns ``(model, tokenizer, config_dict)``. Any failure
    propagates to the caller, so the caller can avoid publishing partial state.
    """
    from huggingface_hub import hf_hub_download

    if hf_token:
        login(token=hf_token)

    # NOTE(review): trust_remote_code=True runs Python code shipped in the
    # repo, and the repo name comes straight from a web form. Anyone who can
    # reach this page can therefore execute code on the server. Consider an
    # allow-list of repos.
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.float16,
        device_map="auto",
        trust_remote_code=True,
    )
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

    # Extra configuration file fetched from the same repo as the weights.
    # If it is missing, the whole load fails.
    config_path = hf_hub_download(model_name, "ee_config.json")
    with open(config_path, encoding="utf-8") as f:
        config = json.load(f)

    return model, tokenizer, config


@app.route("/", methods=["GET", "POST"])
def index():
    """Render the setup page and, on POST, load a model into this server.

    POST form fields:
        ee_model_name: Hugging Face repo id to load.
        hf_token: Hugging Face access token. If empty, no login is attempted
            and the load relies on any existing credentials or on the repo
            being public.

    On success, ``ee_model``, ``ee_tokenizer``, ``ee_config`` and
    ``ee_model_name`` are replaced together. On any failure they are left
    untouched, so a previously loaded model stays active and the page keeps
    reporting its name. Failures are reported through ``flash`` with the
    "danger" category.
    """
    global ee_model, ee_tokenizer, ee_config, ee_model_name

    if request.method == "POST":
        requested_name = request.form.get("ee_model_name", "").strip()
        hf_token = request.form.get("hf_token", "").strip()

        if not requested_name:
            flash("Error: a model name is required.", "danger")
        else:
            try:
                model, tokenizer, config = _load_ee_model(requested_name, hf_token)
            except Exception as e:
                flash(f"Error: {e}", "danger")
            else:
                # Publish all state at once, only after every step succeeded.
                ee_model, ee_tokenizer, ee_config, ee_model_name = (
                    model,
                    tokenizer,
                    config,
                    requested_name,
                )
                flash(f"✅ Model loaded: {ee_model_name}", "success")
                flash("Point your Client Space to this Space's URL below.", "info")

    return render_template(
        "index.html",
        server_ready=(ee_model is not None),
        model_name=ee_model_name if ee_config else None,
        space_url=SPACE_URL,
    )
| |
|
| |
|
@app.route("/generate", methods=["POST"])
def generate():
    """Run the loaded model on client-supplied input embeddings.

    The client is expected to send embeddings already transformed into its
    "sigma-space". Only the final hidden state is sent back; the lm_head step
    is left to the client.

    JSON body:
        inputs_embeds: nested list of floats, 3-D. Assumed to be
            (batch, seq_len, hidden_size), which is the usual HF convention;
            TODO confirm with the client.
        attention_mask (optional): nested list of ints shaped
            (batch, seq_len). Defaults to all ones.
        past_key_values (optional): accepted but ignored. The model runs with
            ``use_cache=False``, so each request is processed from scratch
            and no KV cache is returned.

    Returns:
        JSON ``{"last_hidden": [...]}``: the last entry of the model's
        ``hidden_states`` for the whole input, moved to CPU.
        NOTE(review): calling the full causal-LM model likely still computes
        logits internally; they are discarded, not returned.

    Errors:
        400 if no model is loaded or the body is malformed.
        500 with ``error`` and ``traceback`` keys if inference fails.
    """
    if ee_model is None:
        return jsonify({"error": "Server not started yet"}), 400

    # silent=True: a missing or non-JSON body yields None instead of raising,
    # so it can be reported as a client error rather than a 500.
    data = request.get_json(silent=True)
    if not isinstance(data, dict) or "inputs_embeds" not in data:
        return jsonify(
            {"error": "Request body must be a JSON object with 'inputs_embeds'"}
        ), 400

    try:
        model_dtype = next(ee_model.parameters()).dtype
        device = ee_model.device

        inputs_embeds = torch.tensor(data["inputs_embeds"]).to(
            dtype=model_dtype, device=device
        )
        if inputs_embeds.dim() != 3:
            return jsonify(
                {"error": "inputs_embeds must be 3-D: (batch, seq_len, hidden_size)"}
            ), 400

        mask = data.get("attention_mask")
        if mask is None:
            # One all-ones row per batch item. The previous default always
            # had a single row, whatever the batch size.
            attention_mask = torch.ones(
                inputs_embeds.shape[:2], dtype=torch.long, device=device
            )
        else:
            attention_mask = torch.tensor(mask).to(device=device)

        with torch.no_grad():
            out = ee_model(
                inputs_embeds=inputs_embeds,
                attention_mask=attention_mask,
                use_cache=False,
                output_hidden_states=True,
            )

        last_hidden = out.hidden_states[-1]
        return jsonify({"last_hidden": last_hidden.cpu().tolist()})

    except Exception as e:
        import traceback

        app.logger.exception("generate failed")
        # NOTE(review): sending the traceback exposes server internals to any
        # caller. It is kept only because clients may rely on this key.
        return jsonify({"error": str(e), "traceback": traceback.format_exc()}), 500
| |
|
| |
|
if __name__ == "__main__":
    # Listen on every network interface. The port matches the one used by
    # the localhost fallback for SPACE_URL.
    bind_host, bind_port = "0.0.0.0", 7860
    app.run(host=bind_host, port=bind_port)