File size: 2,282 Bytes
f3e93f5
 
3946a18
f3e93f5
3946a18
f3e93f5
 
 
3946a18
401d729
bedbc46
931699c
5f1f5da
 
e198698
3946a18
 
 
 
 
e15d2e8
5f1f5da
f3e93f5
3946a18
 
 
f3e93f5
5f1f5da
 
 
 
 
 
 
 
 
7a517c1
 
931699c
3946a18
931699c
 
 
f3e93f5
 
7210507
f3e93f5
 
5f1f5da
f810aee
780a92a
 
 
 
 
 
 
 
 
 
 
 
5f1f5da
 
 
931699c
 
5f1f5da
 
 
931699c
5f1f5da
 
 
931699c
 
5f1f5da
931699c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
from fastapi import FastAPI
from pydantic import BaseModel
from typing import List, Dict
from transformers import pipeline
import os

app = FastAPI()

# === MODELS (unchanged per your request) ===
# Alternative, larger zero-shot model kept for reference:
# ZSC_MODEL="MoritzLaurer/deberta-v3-large-zeroshot-v2.0-c"
ZSC_MODEL = "facebook/bart-large-mnli"          # NLI model used for zero-shot classification
SA_MODEL  = "distilbert-base-uncased-finetuned-sst-2-english"  # binary sentiment (SST-2)
SUM_MODEL = "t5-small"                          # lightweight abstractive summarizer

# Fallback label set used by /predict when the request supplies no labels.
DEFAULT_LABELS = ["Usability","Performance","Visual Design","Feedback","Navigation","Responsiveness"]
# Hypothesis template for zero-shot NLI; "{}" is substituted with each candidate label.
TEMPLATE = "This feedback is primarily about {}."

# Pipelines are constructed once at import time so every request reuses the
# loaded models. NOTE(review): first import will download model weights if
# they are not cached locally — confirm deployment pre-fetches them.
zsc  = pipeline("zero-shot-classification", model=ZSC_MODEL)
sa   = pipeline("sentiment-analysis",       model=SA_MODEL)
summ = pipeline("summarization",            model=SUM_MODEL)

class ZSCReq(BaseModel):
    """Request body for POST /predict (zero-shot classification)."""
    text: str                       # the feedback text to classify
    labels: List[str] = []          # candidate labels; empty -> server uses DEFAULT_LABELS
    multi_label: bool = False       # client's multi-label preference (see the /predict handler)
    template: str = TEMPLATE        # hypothesis template; "{}" is replaced by each label

class SAReq(BaseModel):
    """Request body for POST /sa (sentiment analysis)."""
    text: str                       # text to score as POSITIVE/NEGATIVE

class SumReq(BaseModel):
    """Request body for POST /sum (summarization)."""
    text: str                       # text to summarize
    max_length: int = 60            # upper bound on generated summary tokens
    min_length: int = 20            # lower bound on generated summary tokens
    do_sample: bool = False         # False -> deterministic (greedy/beam) generation

@app.on_event("startup")
def warmup():
    """Run one throwaway classification at startup so the first real request
    does not pay the model's cold-start cost.

    Best-effort: any failure is logged and swallowed so a warmup problem
    never prevents the service from starting.
    """
    try:
        zsc("warmup", candidate_labels=DEFAULT_LABELS, multi_label=False, truncation=True)
    except Exception as exc:
        print(f"[warmup] skipped: {exc}")

@app.get("/")
def health():
    """Liveness probe: reports the active zero-shot model and default labels."""
    payload = {
        "status": "ok",
        "model": ZSC_MODEL,
        "labels": DEFAULT_LABELS,
    }
    return payload

@app.post("/predict")
def predict(req: ZSCReq):
    """Zero-shot classify feedback text against candidate labels.

    Falls back to DEFAULT_LABELS when the request supplies no labels, and to
    the module TEMPLATE when it supplies no hypothesis template. Returns
    parallel ``labels`` / ``scores`` lists sorted by descending confidence.
    """
    labels = req.labels or DEFAULT_LABELS
    out = zsc(
        req.text,
        candidate_labels=labels,
        # BUG FIX: honor the request's multi_label flag — it was hard-coded
        # to False, silently ignoring the ZSCReq.multi_label field.
        multi_label=req.multi_label,
        # Use the shared TEMPLATE constant as the fallback instead of
        # duplicating the literal string here.
        hypothesis_template=(req.template or TEMPLATE),
        truncation=True,
    )
    pairs = sorted(zip(out["labels"], out["scores"]), key=lambda p: float(p[1]), reverse=True)
    if not pairs:
        # Degenerate case (e.g. empty label list from the pipeline).
        return {"labels": [], "scores": []}
    lbls, scs = zip(*pairs)
    return {"labels": list(lbls), "scores": [float(s) for s in scs]}

@app.post("/sa")
def sentiment(req: SAReq):
    """Score a single text as POSITIVE/NEGATIVE with a confidence score."""
    results = sa(req.text)
    top = results[0]
    return {"label": top["label"], "score": float(top["score"])}

@app.post("/sum")
def summarize(req: SumReq):
    """Produce an abstractive summary of the request text.

    Generation bounds and sampling mode come straight from the request body.
    """
    gen_kwargs = {
        "max_length": req.max_length,
        "min_length": req.min_length,
        "do_sample": req.do_sample,
        "truncation": True,
    }
    results = summ(req.text, **gen_kwargs)
    return {"summary_text": results[0]["summary_text"]}