adithimshrouthy commited on
Commit
7a517c1
·
verified ·
1 Parent(s): 267a356

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +19 -5
app.py CHANGED
@@ -16,9 +16,9 @@ summ = pipeline("summarization", model=SUM_MODEL)
16
  # === REQUEST SCHEMAS ===
17
  class ZSCReq(BaseModel):
18
  text: str
19
- labels: list[str]
20
  multi_label: bool = True
21
- template: str = "This text is about {}."
22
 
23
  class SAReq(BaseModel):
24
  text: str
@@ -29,6 +29,16 @@ class SumReq(BaseModel):
29
  min_length: int = 20
30
  do_sample: bool = False
31
 
 
 
 
 
 
 
 
 
 
 
32
  # === ROUTES ===
33
  @app.get("/")
34
  def health():
@@ -38,18 +48,22 @@ def health():
38
  # in your HF Space app.py
39
  @app.post("/predict")
40
  def predict(req: ZSCReq):
 
 
 
 
41
  out = zsc(
42
  req.text,
43
  candidate_labels=req.labels,
44
- multi_label=req.multi_label,
45
- hypothesis_template=req.template,
46
  )
47
  return {"labels": out["labels"], "scores": out["scores"]}
48
 
49
  @app.post("/sa")
50
  def sentiment(req: SAReq):
51
  result = sa(req.text)[0]
52
- return {"label": result["label"], "score": float(result["score"])}
53
 
54
  @app.post("/sum")
55
  def summarize(req: SumReq):
 
16
  # === REQUEST SCHEMAS ===
17
  class ZSCReq(BaseModel):
18
  text: str
19
+ labels: list[str] : []
20
  multi_label: bool = True
21
+ template: str = "This feedback is primarily about {}."
22
 
23
  class SAReq(BaseModel):
24
  text: str
 
29
  min_length: int = 20
30
  do_sample: bool = False
31
 
32
+ DEFAULT_LABELS = [
33
+ "Usability", "Performance", "Visual Design",
34
+ "Feedback", "Navigation", "Responsiveness",
35
+ ]
36
+
37
+ @app.on_event("startup")
38
+ def warmup():
39
+ # tiny warmup so first real call isn’t cold
40
+ _ = zsc("warmup", candidate_labels=DEFAULT_LABELS, multi_label=False)
41
+
42
  # === ROUTES ===
43
  @app.get("/")
44
  def health():
 
48
  # in your HF Space app.py
49
  @app.post("/predict")
50
  def predict(req: ZSCReq):
51
+ labels = [l for l in (req.labels or DEFAULT_LABELS) if l in DEFAULT_LABELS]
52
+ if not labels:
53
+ labels = DEFAULT_LABELS
54
+
55
  out = zsc(
56
  req.text,
57
  candidate_labels=req.labels,
58
+ multi_label=False,
59
+ hypothesis_template=req.template or "This feedback is primarily about {}.",
60
  )
61
  return {"labels": out["labels"], "scores": out["scores"]}
62
 
63
  @app.post("/sa")
64
  def sentiment(req: SAReq):
65
  result = sa(req.text)[0]
66
+ return {"label": result["label"], "scores": [float(s) for s in out["scores"]],}
67
 
68
  @app.post("/sum")
69
  def summarize(req: SumReq):