Sana42's picture
FastAPI zero-shot classification with DeBERTa-v3-xsmall
ef2effd
raw
history blame contribute delete
741 Bytes
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import pipeline
# Hugging Face Hub checkpoint used by the zero-shot classification pipeline.
MODEL_NAME = "MoritzLaurer/DeBERTa-v3-xsmall-mnli-fever-anli-ling-binary"

# The ASGI application that the endpoints below are registered on.
app = FastAPI(title="Complaint Classifier")

# The pipeline is built once, at import time, and the resulting module-level
# object is reused by every request instead of being reloaded per call.
print(f"Loading model {MODEL_NAME} from Hugging Face Hub...")
classifier = pipeline(task="zero-shot-classification", model=MODEL_NAME)
class Request(BaseModel):
    """JSON request body accepted by ``POST /classify``.

    Attributes:
        text: The text to classify.
        labels: Candidate labels the text is scored against.
    """

    # NOTE(review): this model shares its name with ``fastapi.Request``; if
    # that class is ever imported into this module, one will shadow the other.
    # NOTE: no non-empty constraint is declared on either field here.
    text: str
    labels: list[str]
@app.post("/classify")
def classify(req: Request):
    """Score ``req.text`` against the caller-supplied candidate labels.

    Runs the module-level zero-shot ``classifier`` pipeline and reports the
    top-ranked label and its score. It also returns every
    ``(label, score)`` pair in the order the pipeline ranked them.

    Args:
        req: Parsed request body with the text and its candidate labels.

    Returns:
        A dict with ``label`` (best label), ``score`` (its score as a Python
        float) and ``all`` (list of ``(label, score)`` pairs).

    Raises:
        HTTPException: 422 when ``text`` is empty or ``labels`` is an empty
            list. The transformers zero-shot argument handler raises
            ``ValueError`` for such input. Left unchecked, FastAPI would
            report it as a 500 Internal Server Error instead of a client error.
    """
    if not req.text:
        raise HTTPException(
            status_code=422, detail="'text' must be a non-empty string."
        )
    if not req.labels:
        raise HTTPException(
            status_code=422, detail="'labels' must contain at least one label."
        )

    result = classifier(req.text, candidate_labels=req.labels)
    # Convert once to plain Python floats so the response is JSON-serializable
    # regardless of the numeric type the pipeline returns.
    scores = [float(s) for s in result["scores"]]
    return {
        "label": result["labels"][0],
        "score": scores[0],
        "all": list(zip(result["labels"], scores)),
    }