Maria Loskutova committed on
Commit e7bb669 · 1 Parent(s): 7236fda
Files changed (5)
  1. advice.py +176 -0
  2. app.py +179 -0
  3. common.py +173 -0
  4. receipt_total_api.py +70 -0
  5. requirements.txt +10 -0
advice.py ADDED
@@ -0,0 +1,176 @@
+ import os
+ import re
+ import pandas as pd
+ import torch
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+
+ from common import read_json_stdin, write_json_stdout, current_month_snapshot, clean_ru
+
+ ALLOWED_MODEL_ID = "Qwen/Qwen2.5-0.5B-Instruct"
+
+ os.environ.setdefault("OMP_NUM_THREADS", "1")
+ os.environ.setdefault("MKL_NUM_THREADS", "1")
+ try:
+     torch.set_num_threads(1)
+ except Exception:
+     pass
+
+ _DEVICE = torch.device("cpu")
+ _tokenizer = None
+ _model = None
+ _loaded = False
+
+
+ def _load():
+     global _tokenizer, _model, _loaded
+     if _loaded and _tokenizer is not None and _model is not None:
+         return _tokenizer, _model
+
+     _tokenizer = AutoTokenizer.from_pretrained(
+         ALLOWED_MODEL_ID,
+         trust_remote_code=True,
+     )
+     _model = AutoModelForCausalLM.from_pretrained(
+         ALLOWED_MODEL_ID,
+         torch_dtype=torch.float32,
+         low_cpu_mem_usage=True,
+         trust_remote_code=True,
+     ).to(_DEVICE).eval()
+
+     if _tokenizer.pad_token_id is None:
+         _tokenizer.pad_token_id = _tokenizer.eos_token_id
+
+     _loaded = True
+     return _tokenizer, _model
+
+
+ def _gen(messages, tok, mdl, max_new_tokens=200, det=True):
+     txt = tok.apply_chat_template(
+         messages, tokenize=False, add_generation_prompt=True
+     )
+     inputs = tok(
+         txt,
+         return_tensors="pt",
+         padding=True,
+         truncation=True,
+         max_length=1400,
+     ).to(_DEVICE)
+
+     with torch.no_grad():
+         common = dict(
+             max_new_tokens=max_new_tokens,
+             repetition_penalty=1.08 if det else 1.12,
+             no_repeat_ngram_size=5 if det else 6,
+             eos_token_id=tok.eos_token_id,
+             pad_token_id=tok.pad_token_id,
+         )
+         if det:
+             out = mdl.generate(
+                 **inputs,
+                 do_sample=False,
+                 num_beams=4,
+                 **common,
+             )
+         else:
+             out = mdl.generate(
+                 **inputs,
+                 do_sample=True,
+                 temperature=0.8,
+                 top_p=0.9,
+                 top_k=50,
+                 **common,
+             )
+     return tok.decode(out[0], skip_special_tokens=True)
+
+
+ _BULLET_KILL = re.compile(
+     r"(?i)(учитывай данные|данные пользователя|месяц:|доход:|расход:|нетто:|топ стат|вопрос:|assistant)"
+ )
+ _ONLY_PUNCT = re.compile(r"^[-•\s\.\,\;\:\!\?]+$")
+
+
+ def _to_bullets(text: str) -> str:
+     if not text:
+         return ""
+     m = re.search(r"(\n\s*[-*]\s+|\n\s*\d+[\).\s]+|•)", "\n" + text)
+     if m:
+         text = text[m.start() :]
+
+     text = re.sub(r"^\s*[*•]\s+", "- ", text, flags=re.M)
+     text = re.sub(r"^\s*\d+[\).\s]+", "- ", text, flags=re.M)
+
+     uniq, seen = [], set()
+     for ln in text.split("\n"):
+         s = ln.strip()
+         if not s or not s.startswith("- "):
+             continue
+         if _BULLET_KILL.search(s) or _ONLY_PUNCT.match(s):
+             continue
+         s = re.sub(r"\s{2,}", " ", s)
+         s = re.sub(r"\.\s*\.+$", ".", s)
+         key = s.lower()
+         if key in seen:
+             continue
+         seen.add(key)
+         uniq.append(s)
+         if len(uniq) >= 7:
+             break
+
+     return "\n".join(s.replace("- ", "• ", 1) for s in uniq)
+
+
+ def main():
+     req = read_json_stdin()
+
+     tx = req.get("transactions") or []
+     question = (req.get("question") or "").strip()
+
+     df = pd.DataFrame(tx) if tx else None
+     snap = current_month_snapshot(df) if df is not None and not df.empty else {}
+
+     if snap:
+         ctx = [
+             f"Месяц: {snap['month']}",
+             f"Доход: {snap['income_total']:.0f}",
+             f"Расход: {abs(snap['expense_total']):.0f}",
+             f"Нетто: {snap['net']:.0f}",
+         ]
+         if snap.get("top_expense_categories"):
+             ctx.append("Топ статей расходов:")
+             for cat, val in snap["top_expense_categories"]:
+                 ctx.append(f"- {cat}: {abs(val):.0f}")
+         context = "\n".join(ctx)
+     else:
+         context = "Данных за текущий месяц нет."
+
+     system_msg = (
+         "Ты финансовый помощник. Отвечай по-русски. "
+         "Верни ТОЛЬКО список из 5–7 конкретных шагов экономии с цифрами (лимиты, проценты, частота). "
+         "Каждая строка должна начинаться с символов \"- \". Никаких вступлений."
+     )
+     messages = [
+         {"role": "system", "content": system_msg},
+         {
+             "role": "user",
+             "content": (
+                 f"Мои данные за текущий месяц:\n{context}\n\nВопрос: {question}\n"
+                 "Начни ответ сразу со строки, которая начинается с \"- \". Верни только список."
+             ),
+         },
+     ]
+
+     tok, mdl = _load()
+     raw = _gen(messages, tok, mdl, det=True)
+     text = _to_bullets(clean_ru(raw))
+
+     if text.count("\n") + 1 < 3:
+         raw2 = _gen(messages, tok, mdl, det=False)
+         text2 = _to_bullets(clean_ru(raw2))
+         if text2:
+             text = text2
+
+     write_json_stdout({"advice": text})
+
+
+ if __name__ == "__main__":
+     main()
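
advice.py works as a stand-alone stdin/stdout worker: it reads one JSON object with optional "transactions" and "question" keys and prints {"advice": "..."} back. Below is a minimal driver sketch; the transaction values are invented for illustration, and the first run downloads Qwen2.5-0.5B-Instruct and runs it on CPU, so it is slow.

import json
import subprocess

request = {
    "question": "Как сократить расходы в следующем месяце?",
    "transactions": [
        {"date": "2024-05-01", "amount": 90000, "type": "income", "category": "Зарплата"},
        {"date": "2024-05-03", "amount": 1500, "type": "expense", "category": "Кафе"},
        {"date": "2024-05-07", "amount": 3200, "type": "expense", "category": "Продукты"},
    ],
}

# The worker reads the request from stdin and writes {"advice": ...} to stdout.
proc = subprocess.run(
    ["python", "advice.py"],
    input=json.dumps(request, ensure_ascii=False),
    capture_output=True,
    text=True,
    check=True,
)
print(json.loads(proc.stdout)["advice"])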
app.py ADDED
@@ -0,0 +1,179 @@
+ # app.py
+ import os
+ import tempfile
+ from decimal import Decimal
+ from typing import List, Optional
+
+ import pandas as pd
+ from fastapi import FastAPI, HTTPException, UploadFile, File
+ from pydantic import BaseModel
+
+ from common import prepare_components_series, fit_and_forecast, current_month_snapshot
+ from advice import _load as advice_load, _gen as advice_gen, _to_bullets, clean_ru
+ from receipt_total_api import extract_total
+
+ app = FastAPI()
+
+ # ---------- Pydantic models mirroring the Go structs ----------
+
+
+ class Transaction(BaseModel):
+     date: str
+     amount: Decimal
+     type: str
+     category: Optional[str] = None
+     description: Optional[str] = None
+
+
+ class ForecastRequest(BaseModel):
+     granularity: str
+     steps: int
+     model: Optional[str] = None
+     transactions: List[Transaction]
+
+
+ class ForecastResponse(BaseModel):
+     period_end: List[str]
+     income_forecast: List[float]
+     expense_forecast: List[float]
+
+
+ class AdviceRequest(BaseModel):
+     question: Optional[str] = None
+     transactions: List[Transaction] = []
+
+
+ class AdviceResponse(BaseModel):
+     advice: str
+
+
+ class ReceiptResponse(BaseModel):
+     total: Optional[float]
+
+
+ # ---------- Start up and preload the models ----------
+
+ advice_tokenizer = None
+ advice_model = None
+
+
+ @app.on_event("startup")
+ def load_models():
+     """
+     Load Qwen once when the service starts.
+     The Donut model for receipts is loaded in receipt_total_api on first import.
+     """
+     global advice_tokenizer, advice_model
+     advice_tokenizer, advice_model = advice_load()
+
+
+ # ---------- Endpoints ----------
+
+
+ @app.post("/forecast", response_model=ForecastResponse)
+ def forecast(req: ForecastRequest):
+     if not req.transactions:
+         raise HTTPException(status_code=400, detail="transactions is empty")
+
+     df = pd.DataFrame([t.dict() for t in req.transactions])
+
+     gran = (req.granularity or "month").lower()
+     freq = "A-DEC" if gran.startswith("y") else "M"
+     steps = int(req.steps or 1)
+     method = (req.model or "auto").lower()
+
+     inc, exp, _ = prepare_components_series(df, freq=freq)
+     inc_fc = fit_and_forecast(inc, steps, freq, method=method)
+     exp_fc = fit_and_forecast(exp, steps, freq, method=method)
+
+     return ForecastResponse(
+         period_end=[
+             d.strftime("%Y-%m-%d") for d in inc_fc.index.to_pydatetime().tolist()
+         ],
+         income_forecast=[float(x) for x in inc_fc.values.tolist()],
+         expense_forecast=[float(x) for x in exp_fc.values.tolist()],
+     )
+
+
+ @app.post("/advice", response_model=AdviceResponse)
+ def advice(req: AdviceRequest):
+     tx = [t.dict() for t in req.transactions] if req.transactions else []
+     df = pd.DataFrame(tx) if tx else None
+     snap = current_month_snapshot(df) if df is not None and not df.empty else {}
+
+     if snap:
+         ctx = [
+             f"Месяц: {snap['month']}",
+             f"Доход: {snap['income_total']:.0f}",
+             f"Расход: {abs(snap['expense_total']):.0f}",
+             f"Нетто: {snap['net']:.0f}",
+         ]
+         if snap.get("top_expense_categories"):
+             ctx.append("Топ статей расходов:")
+             for cat, val in snap["top_expense_categories"]:
+                 ctx.append(f"- {cat}: {abs(val):.0f}")
+         context = "\n".join(ctx)
+     else:
+         context = "Данных за текущий месяц нет."
+
+     question = (req.question or "").strip()
+
+     system_msg = (
+         "Ты финансовый помощник. Отвечай по-русски. "
+         "Верни ТОЛЬКО список из 5–7 конкретных шагов экономии с цифрами (лимиты, проценты, частота). "
+         "Каждая строка должна начинаться с символов \"- \". Никаких вступлений."
+     )
+     messages = [
+         {"role": "system", "content": system_msg},
+         {
+             "role": "user",
+             "content": (
+                 f"Мои данные за текущий месяц:\n{context}\n\nВопрос: {question}\n"
+                 'Начни ответ сразу со строки, которая начинается с "- ". Верни только список.'
+             ),
+         },
+     ]
+
+     raw = advice_gen(messages, advice_tokenizer, advice_model, det=True)
+     text = _to_bullets(clean_ru(raw))
+
+     # fall back to a stochastic pass if too few bullet points came back
+     from advice import _gen as advice_gen2, _to_bullets as to_bullets2
+
+     if text.count("\n") + 1 < 3:
+         raw2 = advice_gen2(messages, advice_tokenizer, advice_model, det=False)
+         text2 = to_bullets2(clean_ru(raw2))
+         if text2:
+             text = text2
+
+     return AdviceResponse(advice=text)
+
+
+ @app.post("/receipt-total-file", response_model=ReceiptResponse)
+ async def receipt_total_file(file: UploadFile = File(...)):
+     """
+     Receives a receipt file (multipart/form-data, field "file"),
+     saves it to a temporary file, calls extract_total and returns the total.
+     """
+     # save to a temporary file
+     suffix = os.path.splitext(file.filename or "")[1] or ".jpg"
+     with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
+         contents = await file.read()
+         tmp.write(contents)
+         tmp_path = tmp.name
+
+     try:
+         total = extract_total(tmp_path)
+         return ReceiptResponse(total=total)
+     finally:
+         try:
+             os.remove(tmp_path)
+         except OSError:
+             pass
+
+
+ # For local runs; on an HF Space this block is not used
+ if __name__ == "__main__":
+     import uvicorn
+
+     uvicorn.run("app:app", host="0.0.0.0", port=7860, workers=1)
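
With the service running locally (for example via the __main__ block above, on port 7860), the two JSON endpoints can be exercised with the standard library alone; the payload below is made-up sample data shaped like the Pydantic models. The multipart /receipt-total-file endpoint is not shown here, and note that FastAPI needs the python-multipart package installed to parse form uploads for it.

import json
import urllib.request

BASE = "http://localhost:7860"  # assumes a local run of app.py

def post_json(path, payload):
    # POST a JSON body and decode the JSON response.
    req = urllib.request.Request(
        BASE + path,
        data=json.dumps(payload).encode("utf-8"),
        headers={"Content-Type": "application/json"},
    )
    with urllib.request.urlopen(req) as resp:
        return json.loads(resp.read().decode("utf-8"))

transactions = [
    {"date": "2024-03-05", "amount": "90000", "type": "income", "category": "Зарплата"},
    {"date": "2024-03-12", "amount": "4200", "type": "expense", "category": "Продукты"},
    {"date": "2024-04-05", "amount": "92000", "type": "income", "category": "Зарплата"},
    {"date": "2024-04-14", "amount": "4700", "type": "expense", "category": "Продукты"},
]

# Two-month income/expense forecast.
print(post_json("/forecast", {"granularity": "month", "steps": 2, "model": "auto", "transactions": transactions}))

# Saving tips based on the latest month of data.
print(post_json("/advice", {"question": "Как сэкономить?", "transactions": transactions}))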
common.py ADDED
@@ -0,0 +1,173 @@
+ import re
+ import json
+ import numpy as np
+ import pandas as pd
+ from typing import Optional, Tuple
+ from statsmodels.tsa.holtwinters import ExponentialSmoothing, Holt
+
+ try:
+     from prophet import Prophet
+     _HAS_PROPHET = True
+ except Exception:
+     _HAS_PROPHET = False
+
+ _KEEP = re.compile(r"[^А-Яа-яЁё0-9 ,.!?:;()«»\"'–—\-•\n]")
+
+ def clean_ru(text: str) -> str:
+     text = _KEEP.sub(" ", text or "")
+     return re.sub(r"\s+", " ", text).strip()
+
+ def normalize_columns(df: pd.DataFrame) -> pd.DataFrame:
+     work = df.copy()
+     for col in list(work.columns):
+         lc = col.lower()
+         if lc in ("date", "дата"):
+             work.rename(columns={col: "date"}, inplace=True)
+         elif lc in ("amount", "сумма"):
+             work.rename(columns={col: "amount"}, inplace=True)
+         elif lc in ("category", "категория"):
+             work.rename(columns={col: "category"}, inplace=True)
+         elif lc in ("type", "тип"):
+             work.rename(columns={col: "type"}, inplace=True)
+     required = {"date", "amount", "type"}
+     missing = required - set(map(str, work.columns))
+     if missing:
+         raise ValueError(f"Отсутствуют колонки: {', '.join(sorted(missing))}")
+     work["date"] = pd.to_datetime(work["date"], errors="coerce")
+     work = work.dropna(subset=["date"])
+     work["amount"] = pd.to_numeric(work["amount"], errors="coerce").fillna(0.0)
+     if "category" not in work.columns:
+         work["category"] = "Без категории"
+     return work
+
+ def is_expense(t: str) -> bool:
+     t = str(t).strip().lower()
+     return t in {"expense", "расход", "расходы", "-", "e", "exp"}
+
+ def is_income(t: str) -> bool:
+     t = str(t).strip().lower()
+     return t in {"income", "доход", "+", "i", "inc"}
+
+ def prepare_components_series(df: pd.DataFrame, freq: str="M") -> Tuple[pd.Series, pd.Series, pd.Series]:
+     if df is None or df.empty:
+         raise ValueError("Пустая таблица транзакций.")
+     work = normalize_columns(df)
+     work["is_expense"] = work["type"].apply(is_expense)
+     work["is_income"] = work["type"].apply(is_income)
+
+     inc = work.loc[work["is_income"]].set_index("date")["amount"].resample(freq).sum().sort_index()
+     exp = work.loc[work["is_expense"]].set_index("date")["amount"].abs().mul(-1).resample(freq).sum().sort_index()
+
+     if not inc.empty or not exp.empty:
+         start = min([x.index.min() for x in [inc, exp] if not x.empty])
+         end = max([x.index.max() for x in [inc, exp] if not x.empty])
+         full_idx = pd.date_range(start, end, freq=freq)
+         inc = inc.reindex(full_idx, fill_value=0.0)
+         exp = exp.reindex(full_idx, fill_value=0.0)
+     net = inc + exp
+     inc.index.name = exp.index.name = net.index.name = "period_end"
+     return inc, exp, net
+
+ def fit_and_forecast(history: pd.Series, steps: int, freq: str, method: str = "auto") -> pd.Series:
+     if len(history) < 3:
+         last = float(history.iloc[-1]) if len(history) else 0.0
+         start = (history.index[-1] if len(history) else pd.Timestamp.today().normalize()) + \
+             pd.tseries.frequencies.to_offset(freq)
+         idx = pd.date_range(start, periods=steps, freq=freq)
+         return pd.Series([last] * steps, index=idx, name="forecast")
+
+     use_prophet = False
+     if method == "prophet":
+         use_prophet = True
+     elif method == "auto":
+         if freq.startswith("A"):
+             use_prophet = _HAS_PROPHET and (len(history) >= 5)
+         else:
+             use_prophet = _HAS_PROPHET and (len(history) >= 18)
+
+     if use_prophet:
+         try:
+             pfreq = "Y" if freq.startswith("A") else "M"
+             dfp = history.reset_index()
+             dfp.columns = ["ds", "y"]
+
+             m = Prophet(
+                 yearly_seasonality=(pfreq == "M"),
+                 weekly_seasonality=False,
+                 daily_seasonality=False,
+                 seasonality_mode="additive",
+             )
+             m.fit(dfp)
+             future = m.make_future_dataframe(periods=steps, freq=pfreq)
+             fcst = m.predict(future).tail(steps)
+             yhat = pd.Series(fcst["yhat"].values, index=pd.DatetimeIndex(fcst["ds"]), name="forecast")
+
+             if pfreq == "M":
+                 yhat.index = yhat.index.to_period("M").to_timestamp(how="end")
+             else:
+                 yhat.index = yhat.index.to_period("Y").to_timestamp(how="end")
+
+             if yhat.index.freq is None:
+                 yhat.index = pd.date_range(yhat.index[0], periods=len(yhat), freq=("A-DEC" if pfreq == "Y" else "M"))
+             return yhat
+         except Exception:
+             pass
+
+     try:
+         if freq.startswith("A"):
+             model = Holt(history, initialization_method="estimated")
+         else:
+             if len(history) >= 24:
+                 model = ExponentialSmoothing(
+                     history, trend="add", seasonal="add", seasonal_periods=12,
+                     initialization_method="estimated"
+                 )
+             else:
+                 model = Holt(history, initialization_method="estimated")
+         fit = model.fit(optimized=True)
+         fc = fit.forecast(steps)
+         if not isinstance(fc.index, pd.DatetimeIndex) or len(fc.index) != steps:
+             start = history.index[-1] + pd.tseries.frequencies.to_offset(freq)
+             idx = pd.date_range(start, periods=steps, freq=freq)
+             fc = pd.Series(np.asarray(fc), index=idx, name="forecast")
+         return fc
+     except Exception:
+         tail = min(6, len(history))
+         baseline = float(history.tail(tail).mean()) if tail else 0.0
+         start = history.index[-1] + pd.tseries.frequencies.to_offset(freq)
+         idx = pd.date_range(start, periods=steps, freq=freq)
+         return pd.Series([baseline] * steps, index=idx, name="forecast")
+
+ def current_month_snapshot(df: pd.DataFrame) -> dict:
+     if df is None or df.empty:
+         return {}
+     w = normalize_columns(df)
+     w["is_income"] = w["type"].apply(is_income)
+     w["is_expense"] = w["type"].apply(is_expense)
+     lastp = w["date"].dt.to_period("M").max()
+     cur = w[w["date"].dt.to_period("M") == lastp].copy()
+     if cur.empty:
+         return {}
+     income_total = float(cur.loc[cur["is_income"], "amount"].sum())
+     expense_total = -float(cur.loc[cur["is_expense"], "amount"].abs().sum())
+     net = income_total + expense_total
+     exp_df = cur.loc[cur["is_expense"], ["category","amount"]].copy()
+     exp_df["amount"] = -exp_df["amount"].abs()
+     top = exp_df.groupby("category")["amount"].sum().sort_values().head(5)
+     return {
+         "month": str(lastp),
+         "income_total": income_total,
+         "expense_total": expense_total,
+         "net": net,
+         "top_expense_categories": [(str(k), float(v)) for k,v in top.items()]
+     }
+
+ def read_json_stdin() -> dict:
+     import sys
+     raw = sys.stdin.read()
+     return json.loads(raw or "{}")
+
+ def write_json_stdout(obj) -> None:
+     import sys
+     sys.stdout.write(json.dumps(obj, ensure_ascii=False))
+     sys.stdout.flush()
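
A quick way to see what prepare_components_series and fit_and_forecast produce together, using a tiny hand-made transaction table (the numbers are illustrative):

import pandas as pd
from common import prepare_components_series, fit_and_forecast

df = pd.DataFrame({
    "date":   ["2024-01-10", "2024-02-10", "2024-03-10", "2024-01-20", "2024-02-20", "2024-03-20"],
    "amount": [90000, 90000, 95000, 30000, 32000, 31000],
    "type":   ["income", "income", "income", "expense", "expense", "expense"],
})

# Monthly income, expense (stored as negative values) and net series on a shared index.
inc, exp, net = prepare_components_series(df, freq="M")

# With only three monthly points the "auto" method skips Prophet and tries Holt;
# if the fit fails it falls back to a flat forecast from the recent mean.
print(fit_and_forecast(inc, steps=2, freq="M"))
print(fit_and_forecast(exp, steps=2, freq="M"))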
receipt_total_api.py ADDED
@@ -0,0 +1,70 @@
+ import sys
+ import os
+ import json
+ import re
+ import torch
+ import numpy as np
+ from PIL import Image
+ from transformers import DonutProcessor, VisionEncoderDecoderModel
+
+ #os.environ["TRANSFORMERS_CACHE"] = "/tmp"
+ #os.environ["HF_HOME"] = "/tmp"
+
+ MODEL_ID = "naver-clova-ix/donut-base-finetuned-cord-v2"
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+
+ processor = DonutProcessor.from_pretrained(MODEL_ID)
+ model = VisionEncoderDecoderModel.from_pretrained(MODEL_ID).to(device)
+
+ def pick_total_from_text(text: str):
+     if not text:
+         return None
+     text = text.replace("\xa0", " ")
+     def _to_float(s):
+         s = s.replace(" ", "").replace(",", ".")
+         try: return float(s)
+         except Exception: return None
+     eq_matches = re.findall(r"=\s*(-?\d{1,3}(?:[ .,\u00A0]?\d{3})*(?:[.,]\d{2}))", text)
+     for m in reversed(eq_matches):
+         v = _to_float(m)
+         if v and v > 0: return v
+     matches = re.findall(r"(-?\d{1,3}(?:[ .,\u00A0]?\d{3})*(?:[.,]\d{2}))", text)
+     best = None
+     for m in matches:
+         v = _to_float(m)
+         if v and 0 < v < 1e6:
+             best = v
+     return best
+
+ def extract_total(image_path: str):
+     image = Image.open(image_path).convert("RGB")
+     task_prompt = "<s_cord-v2>"
+     pixel_values = processor(image, return_tensors="pt").pixel_values.to(device)
+     decoder_input_ids = processor.tokenizer(task_prompt, add_special_tokens=False, return_tensors="pt").input_ids.to(device)
+     outputs = model.generate(
+         pixel_values,
+         decoder_input_ids=decoder_input_ids,
+         max_length=model.config.decoder.max_position_embeddings,
+         early_stopping=True,
+         pad_token_id=processor.tokenizer.pad_token_id,
+         eos_token_id=processor.tokenizer.eos_token_id,
+         use_cache=True,
+         num_beams=1,
+     )
+     seq = processor.batch_decode(outputs, skip_special_tokens=True)[0]
+     seq = seq.replace(task_prompt, "").replace("<s>", "").replace("</s>", "").strip()
+     try:
+         data = json.loads(seq)
+         for k in ["total", "total_price", "grand_total"]:
+             if k in data:
+                 return float(str(data[k]).replace(",", "."))
+     except Exception:
+         pass
+     return pick_total_from_text(seq)
+
+ if __name__ == "__main__":
+     if len(sys.argv) < 2:
+         print("Usage: receipt_total_api.py path/to/receipt.jpg", file=sys.stderr)
+         sys.exit(1)
+     total = extract_total(sys.argv[1])
+     print(total if total is not None else "null")
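
The regex fallback can be sanity-checked without a receipt image; the strings below are invented. Note that importing the module already downloads and loads the Donut checkpoint, so this sketch is mainly useful where the model is cached.

from receipt_total_api import pick_total_from_text  # import loads the Donut model

# An amount written after "=" is preferred as the grand total.
print(pick_total_from_text("Молоко 89,90 Хлеб 45,00 ИТОГО = 1 234,50"))  # 1234.5

# Without an "=" marker, the last plausible amount below 1e6 wins.
print(pick_total_from_text("Сумма 199,00 Оплата картой 199,00"))  # 199.0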
requirements.txt ADDED
@@ -0,0 +1,10 @@
+ fastapi
+ uvicorn[standard]
+ pydantic
+ pandas
+ numpy
+ torch
+ transformers
+ statsmodels
+ pillow
+ prophet