zzejiao commited on
Commit
c18fc4e
·
1 Parent(s): e15bf8b

update k number to 6

Browse files
Files changed (2) hide show
  1. src/Rag.py +1 -3
  2. src/app.py +0 -1
src/Rag.py CHANGED
@@ -17,7 +17,6 @@ embedder = None
17
  index = None
18
  llm_client = None
19
 
20
- # 添加线程锁以确保多用户并发安全
21
  import threading
22
  _model_lock = threading.Lock()
23
 
@@ -41,7 +40,6 @@ def save_embeddings(embedder_name, embeddings):
41
  Save embeddings to a .npy file.
42
  """
43
  file_path = os.path.join("data", "embeddings", f"{embedder_name.replace('/', '_')}.npy")
44
- # 确保目录存在
45
  os.makedirs(os.path.dirname(file_path), exist_ok=True)
46
  np.save(file_path, embeddings)
47
  print(f"Saved embeddings for {embedder_name}...")
@@ -360,7 +358,7 @@ def depression_assistant(query, model_name="meta-llama/Llama-3.3-70B-Instruct-Tu
360
  global db, referenced_tables_db, embedder, index, llm_client
361
 
362
  t1 = time.perf_counter()
363
- results = faiss_search(query, embedder, db, index, referenced_tables_db, k=3)
364
  t2 = time.perf_counter()
365
  print(f"[Time] FAISS search done in {t2 - t1:.2f} seconds.")
366
 
 
17
  index = None
18
  llm_client = None
19
 
 
20
  import threading
21
  _model_lock = threading.Lock()
22
 
 
40
  Save embeddings to a .npy file.
41
  """
42
  file_path = os.path.join("data", "embeddings", f"{embedder_name.replace('/', '_')}.npy")
 
43
  os.makedirs(os.path.dirname(file_path), exist_ok=True)
44
  np.save(file_path, embeddings)
45
  print(f"Saved embeddings for {embedder_name}...")
 
358
  global db, referenced_tables_db, embedder, index, llm_client
359
 
360
  t1 = time.perf_counter()
361
+ results = faiss_search(query, embedder, db, index, referenced_tables_db)
362
  t2 = time.perf_counter()
363
  print(f"[Time] FAISS search done in {t2 - t1:.2f} seconds.")
364
 
src/app.py CHANGED
@@ -32,7 +32,6 @@ def get_llm_client(client_type, api_key):
32
  )
33
  return None
34
 
35
- # ===== 辅助函数 =====
36
  def get_current_model_info():
37
  if "cached_model_info" in st.session_state and st.session_state.cached_model_info:
38
  return st.session_state.cached_model_info
 
32
  )
33
  return None
34
 
 
35
  def get_current_model_info():
36
  if "cached_model_info" in st.session_state and st.session_state.cached_model_info:
37
  return st.session_state.cached_model_info