Spaces:
Sleeping
Sleeping
File size: 2,282 Bytes
f3e93f5 3946a18 f3e93f5 3946a18 f3e93f5 3946a18 401d729 bedbc46 931699c 5f1f5da e198698 3946a18 e15d2e8 5f1f5da f3e93f5 3946a18 f3e93f5 5f1f5da 7a517c1 931699c 3946a18 931699c f3e93f5 7210507 f3e93f5 5f1f5da f810aee 780a92a 5f1f5da 931699c 5f1f5da 931699c 5f1f5da 931699c 5f1f5da 931699c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 |
from fastapi import FastAPI
from pydantic import BaseModel
from typing import List, Dict
from transformers import pipeline
import os
# FastAPI application object; the route handlers in this file register on it.
app = FastAPI()

# === Model checkpoints ===
# Alternative zero-shot checkpoint that can be swapped in for ZSC_MODEL:
#   MoritzLaurer/deberta-v3-large-zeroshot-v2.0-c
ZSC_MODEL = "facebook/bart-large-mnli"
SA_MODEL = "distilbert-base-uncased-finetuned-sst-2-english"
SUM_MODEL = "t5-small"

# Candidate labels used when a /predict request supplies none.
DEFAULT_LABELS = [
    "Usability",
    "Performance",
    "Visual Design",
    "Feedback",
    "Navigation",
    "Responsiveness",
]

# Default hypothesis template for zero-shot classification requests.
TEMPLATE = "This feedback is primarily about {}."

# Pipelines are constructed once at import time and reused by every request.
zsc = pipeline("zero-shot-classification", model=ZSC_MODEL)
sa = pipeline("sentiment-analysis", model=SA_MODEL)
summ = pipeline("summarization", model=SUM_MODEL)
class ZSCReq(BaseModel):
    """Request body for the zero-shot classification endpoint (``/predict``)."""

    # Text to classify.
    text: str
    # Candidate labels; when empty, the handler falls back to DEFAULT_LABELS.
    labels: List[str] = []
    # Multi-label flag accepted in the request body.
    multi_label: bool = False
    # Hypothesis template; defaults to the module-level TEMPLATE.
    template: str = TEMPLATE
class SAReq(BaseModel):
    """Request body for the sentiment-analysis endpoint (``/sa``)."""

    # Text whose sentiment is scored.
    text: str
class SumReq(BaseModel):
    """Request body for the summarization endpoint (``/sum``)."""

    # Text to summarize.
    text: str
    # Generation settings forwarded unchanged to the summarization pipeline.
    max_length: int = 60
    min_length: int = 20
    do_sample: bool = False
@app.on_event("startup")
def warmup():
    """Run one throwaway zero-shot classification when the app starts.

    Best-effort: any exception is reported on stdout and otherwise ignored,
    so a failure here is not propagated out of the startup hook.
    """
    try:
        zsc(
            "warmup",
            candidate_labels=DEFAULT_LABELS,
            multi_label=False,
            truncation=True,
        )
    except Exception as e:
        print(f"[warmup] skipped: {e}")
@app.get("/")
def health():
    """Report service status plus the zero-shot model name and default labels."""
    payload = {"status": "ok"}
    payload["model"] = ZSC_MODEL
    payload["labels"] = DEFAULT_LABELS
    return payload
@app.post("/predict")
def predict(req: ZSCReq):
    """Classify ``req.text`` against candidate labels with the zero-shot model.

    Args:
        req: Request body. Empty ``labels`` falls back to ``DEFAULT_LABELS``
            and an empty ``template`` falls back to ``TEMPLATE``.

    Returns:
        A dict with ``labels`` and ``scores`` (plain floats), paired by
        position and ordered from highest to lowest score.
    """
    labels = req.labels or DEFAULT_LABELS
    out = zsc(
        req.text,
        candidate_labels=labels,
        # Honor the client's flag instead of hard-coding False; the request
        # default (False) keeps the previous single-best-label behavior.
        multi_label=req.multi_label,
        hypothesis_template=req.template or TEMPLATE,
        truncation=True,
    )
    # Sort explicitly so the response order never depends on pipeline internals.
    ranked = sorted(
        zip(out["labels"], out["scores"]),
        key=lambda pair: float(pair[1]),
        reverse=True,
    )
    # The comprehensions yield two empty lists when there is nothing to rank.
    return {
        "labels": [label for label, _ in ranked],
        "scores": [float(score) for _, score in ranked],
    }
@app.post("/sa")
def sentiment(req: SAReq):
    """Return the sentiment label and score for ``req.text``.

    Args:
        req: Request body carrying the text to analyze.

    Returns:
        A dict with ``label`` (as reported by the pipeline) and ``score``
        converted to a plain float.
    """
    # truncation=True matches the /predict and /sum handlers, so text longer
    # than the model's maximum sequence length is truncated by the tokenizer
    # instead of being passed to the model at full length.
    result = sa(req.text, truncation=True)[0]
    return {"label": result["label"], "score": float(result["score"])}
@app.post("/sum")
def summarize(req: SumReq):
    """Summarize ``req.text`` with the summarization pipeline.

    Args:
        req: Request body carrying the text and the generation settings
            (``max_length``, ``min_length``, ``do_sample``) that are
            forwarded to the pipeline.

    Returns:
        A dict with the single key ``summary_text``.
    """
    results = summ(
        req.text,
        max_length=req.max_length,
        min_length=req.min_length,
        do_sample=req.do_sample,
        truncation=True,
    )
    # Only the first entry of the pipeline output is used.
    return {"summary_text": results[0]["summary_text"]}