Add dataset with table parser type for Infinity and answer question in chat using SQL (#12541)
### What problem does this PR solve?

1) Create dataset using table parser for Infinity
2) Answer questions in chat using SQL

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
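In outline, the chat path this PR adds tries SQL first and falls back to embedding retrieval. A minimal sketch of that control flow (helper names such as `get_field_map`, `chat_model`, and `vector_retrieval_answer` are hypothetical stand-ins, not APIs from this PR):

```python
# Sketch only: if the KB has a field map (table parser), try SQL first;
# otherwise, or on failure, fall back to embedding-based retrieval.
async def answer(question: str, dialog) -> dict:
    field_map = get_field_map(dialog.kb_ids)  # hypothetical accessor
    if field_map:
        ans = await use_sql(question, field_map, dialog.tenant_id,
                            chat_model, True, dialog.kb_ids)
        if ans and (ans["reference"]["chunks"] or ans["answer"]):
            return ans
    return await vector_retrieval_answer(question, dialog)  # hypothetical fallback
```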
@@ -37,7 +37,6 @@ from api.db.services.tenant_llm_service import TenantLLMService
from common.time_utils import current_timestamp, datetime_format
from graphrag.general.mind_map_extractor import MindMapExtractor
from rag.advanced_rag import DeepResearcher
from rag.app.resume import forbidden_select_fields4resume
from rag.app.tag import label_question
from rag.nlp.search import index_name
from rag.prompts.generator import chunks_format, citation_prompt, cross_languages, full_question, kb_prompt, keyword_extraction, message_fit_in, \
@@ -274,6 +273,7 @@ def repair_bad_citation_formats(answer: str, kbinfos: dict, idx: set):


async def async_chat(dialog, messages, stream=True, **kwargs):
    logging.debug("Begin async_chat")
    assert messages[-1]["role"] == "user", "The last content of this conversation is not from user."
    if not dialog.kb_ids and not dialog.prompt_config.get("tavily_api_key"):
        async for ans in async_chat_solo(dialog, messages, stream):
@@ -323,13 +323,20 @@ async def async_chat(dialog, messages, stream=True, **kwargs):

    prompt_config = dialog.prompt_config
    field_map = KnowledgebaseService.get_field_map(dialog.kb_ids)
    logging.debug(f"field_map retrieved: {field_map}")
    # try to use sql if field mapping is good to go
    if field_map:
        logging.debug("Use SQL to retrieval:{}".format(questions[-1]))
        ans = await use_sql(questions[-1], field_map, dialog.tenant_id, chat_mdl, prompt_config.get("quote", True), dialog.kb_ids)
        if ans:
        # For aggregate queries (COUNT, SUM, etc.), chunks may be empty but answer is still valid
        if ans and (ans.get("reference", {}).get("chunks") or ans.get("answer")):
            yield ans
            return
        else:
            logging.debug("SQL failed or returned no results, falling back to vector search")

    param_keys = [p["key"] for p in prompt_config.get("parameters", [])]
    logging.debug(f"attachments={attachments}, param_keys={param_keys}, embd_mdl={embd_mdl}")

    for p in prompt_config["parameters"]:
        if p["key"] == "knowledge":
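The new acceptance check above reads as a small predicate: an SQL answer is kept if it has reference chunks or a non-empty answer text, which is what lets aggregate results (COUNT, SUM) through even with empty chunks. A runnable restatement (sample payloads invented):

```python
def accept_sql_answer(ans: dict | None) -> bool:
    # Mirrors the condition above: keep the SQL answer when it carries
    # reference chunks OR any answer text at all.
    return bool(ans and (ans.get("reference", {}).get("chunks") or ans.get("answer")))

assert accept_sql_answer({"answer": "42", "reference": {"chunks": []}})    # aggregate: no chunks, has answer
assert not accept_sql_answer(None)                                         # SQL path failed entirely
assert not accept_sql_answer({"answer": "", "reference": {"chunks": []}})  # nothing usable
```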
@@ -366,7 +373,8 @@ async def async_chat(dialog, messages, stream=True, **kwargs):
    kbinfos = {"total": 0, "chunks": [], "doc_aggs": []}
    knowledges = []

    if attachments is not None and "knowledge" in [p["key"] for p in prompt_config["parameters"]]:
    if attachments is not None and "knowledge" in param_keys:
        logging.debug("Proceeding with retrieval")
        tenant_ids = list(set([kb.tenant_id for kb in kbs]))
        knowledges = []
        if prompt_config.get("reasoning", False):
@@ -575,112 +583,306 @@ async def async_chat(dialog, messages, stream=True, **kwargs):


async def use_sql(question, field_map, tenant_id, chat_mdl, quota=True, kb_ids=None):
    sys_prompt = """
You are a Database Administrator. You need to check the fields of the following tables based on the user's list of questions and write the SQL corresponding to the last question.
Ensure that:
1. Field names should not start with a digit. If any field name starts with a digit, use double quotes around it.
2. Write only the SQL, no explanations or additional text.
"""
    user_prompt = """
Table name: {};
Table of database fields are as follows:
{}
    logging.debug(f"use_sql: Question: {question}")

Question are as follows:
    # Determine which document engine we're using
    doc_engine = "infinity" if settings.DOC_ENGINE_INFINITY else "es"

    # Construct the full table name
    # For Elasticsearch: ragflow_{tenant_id} (kb_id is in WHERE clause)
    # For Infinity: ragflow_{tenant_id}_{kb_id} (each KB has its own table)
    base_table = index_name(tenant_id)
    if doc_engine == "infinity" and kb_ids and len(kb_ids) == 1:
        # Infinity: append kb_id to table name
        table_name = f"{base_table}_{kb_ids[0]}"
        logging.debug(f"use_sql: Using Infinity table name: {table_name}")
    else:
        # Elasticsearch/OpenSearch: use base index name
        table_name = base_table
        logging.debug(f"use_sql: Using ES/OS table name: {table_name}")

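For orientation, the naming rule from the comments above as a standalone sketch (it assumes `index_name(tenant_id)` yields `ragflow_{tenant_id}`, as the comment states; the helper itself is hypothetical):

```python
# Hypothetical helper mirroring the table-naming rule above.
def sql_table_name(tenant_id: str, kb_ids: list[str] | None, doc_engine: str) -> str:
    base_table = f"ragflow_{tenant_id}"  # what index_name() is assumed to return
    if doc_engine == "infinity" and kb_ids and len(kb_ids) == 1:
        # Infinity keeps one table per knowledge base.
        return f"{base_table}_{kb_ids[0]}"
    # Elasticsearch/OpenSearch share one index; kb_id is filtered in WHERE.
    return base_table

assert sql_table_name("t1", ["kb9"], "infinity") == "ragflow_t1_kb9"
assert sql_table_name("t1", ["kb9"], "es") == "ragflow_t1"
```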
    # Generate engine-specific SQL prompts
    if doc_engine == "infinity":
        # Build Infinity prompts with JSON extraction context
        json_field_names = list(field_map.keys())
        sys_prompt = """You are a Database Administrator. Write SQL for a table with JSON 'chunk_data' column.

JSON Extraction: json_extract_string(chunk_data, '$.FieldName')
Numeric Cast: CAST(json_extract_string(chunk_data, '$.FieldName') AS INTEGER/FLOAT)
NULL Check: json_extract_isnull(chunk_data, '$.FieldName') == false

RULES:
1. Use EXACT field names (case-sensitive) from the list below
2. For SELECT: include doc_id, docnm, and json_extract_string() for requested fields
3. For COUNT: use COUNT(*) or COUNT(DISTINCT json_extract_string(...))
4. Add AS alias for extracted field names
5. DO NOT select 'content' field
6. Only add NULL check (json_extract_isnull() == false) in WHERE clause when:
   - Question asks to "show me" or "display" specific columns
   - Question mentions "not null" or "excluding null"
   - Add NULL check for count specific column
   - DO NOT add NULL check for COUNT(*) queries (COUNT(*) counts all rows including nulls)
7. Output ONLY the SQL, no explanations"""
        user_prompt = """Table: {}
Fields (EXACT case): {}
{}
Please write the SQL, only SQL, without any other explanations or text.
""".format(index_name(tenant_id), "\n".join([f"{k}: {v}" for k, v in field_map.items()]), question)
Question: {}
Write SQL using json_extract_string() with exact field names. Include doc_id, docnm for data queries. Only SQL.""".format(
            table_name,
            ", ".join(json_field_names),
            "\n".join([f"  - {field}" for field in json_field_names]),
            question
        )
    else:
        # Build ES/OS prompts with direct field access
        sys_prompt = """You are a Database Administrator. Write SQL queries.

RULES:
1. Use EXACT field names from the schema below (e.g., product_tks, not product)
2. Quote field names starting with digit: "123_field"
3. Add IS NOT NULL in WHERE clause when:
   - Question asks to "show me" or "display" specific columns
4. Include doc_id/docnm in non-aggregate statement
5. Output ONLY the SQL, no explanations"""
        user_prompt = """Table: {}
Available fields:
{}
Question: {}
Write SQL using exact field names above. Include doc_id, docnm_kwd for data queries. Only SQL.""".format(
            table_name,
            "\n".join([f"  - {k} ({v})" for k, v in field_map.items()]),
            question
        )

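As a concrete illustration of what the two prompt variants are meant to elicit (table and field names below are invented examples, not from this PR):

```python
# ES/OpenSearch: fields are real columns; kb_id is filtered in WHERE.
es_sql = (
    "select doc_id, docnm_kwd, price_flt from ragflow_t1 "
    "where kb_id = 'kb9' and price_flt > 100"
)
# Infinity: row fields live in the JSON 'chunk_data' column of the per-KB
# table, so they are read via json_extract_string(), per the prompt rules.
infinity_sql = (
    "select doc_id, docnm, "
    "CAST(json_extract_string(chunk_data, '$.Price') AS FLOAT) AS Price "
    "from ragflow_t1_kb9 "
    "where json_extract_isnull(chunk_data, '$.Price') == false"
)
```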
    tried_times = 0

    async def get_table():
        nonlocal sys_prompt, user_prompt, question, tried_times
        sql = await chat_mdl.async_chat(sys_prompt, [{"role": "user", "content": user_prompt}], {"temperature": 0.06})
        sql = re.sub(r"^.*</think>", "", sql, flags=re.DOTALL)
        logging.debug(f"{question} ==> {user_prompt} get SQL: {sql}")
        sql = re.sub(r"[\r\n]+", " ", sql.lower())
        sql = re.sub(r".*select ", "select ", sql.lower())
        sql = re.sub(r" +", " ", sql)
        sql = re.sub(r"([;；]|```).*", "", sql)
        sql = re.sub(r"&", "and", sql)
        if sql[: len("select ")] != "select ":
            return None, None
        if not re.search(r"((sum|avg|max|min)\(|group by )", sql.lower()):
            if sql[: len("select *")] != "select *":
                sql = "select doc_id,docnm_kwd," + sql[6:]
            else:
                flds = []
                for k in field_map.keys():
                    if k in forbidden_select_fields4resume:
                        continue
                    if len(flds) > 11:
                        break
                    flds.append(k)
                sql = "select doc_id,docnm_kwd," + ",".join(flds) + sql[8:]
        logging.debug(f"use_sql: Raw SQL from LLM: {repr(sql[:500])}")
        # Remove think blocks if present (format: </think>...)
        sql = re.sub(r"</think>\n.*?\n\s*", "", sql, flags=re.DOTALL)
        sql = re.sub(r"思考\n.*?\n", "", sql, flags=re.DOTALL)
        # Remove markdown code blocks (```sql ... ```)
        sql = re.sub(r"```(?:sql)?\s*", "", sql, flags=re.IGNORECASE)
        sql = re.sub(r"```\s*$", "", sql, flags=re.IGNORECASE)
        # Remove trailing semicolon that ES SQL parser doesn't like
        sql = sql.rstrip().rstrip(';').strip()

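The normalization chain in `get_table` is easiest to see on a concrete reply. A minimal sketch, assuming a typical fenced LLM response (the sample reply is invented):

```python
import re

# Invented LLM reply with a markdown fence and a trailing semicolon.
raw = "```sql\nSELECT doc_id, docnm_kwd, price_flt FROM t;\n```"

sql = re.sub(r"[\r\n]+", " ", raw.lower())  # flatten to one line
sql = re.sub(r".*select ", "select ", sql)  # drop anything before SELECT
sql = re.sub(r" +", " ", sql)               # squeeze repeated spaces
sql = re.sub(r"([;；]|```).*", "", sql)      # cut at semicolon or fence
print(sql)  # -> "select doc_id, docnm_kwd, price_flt from t"
```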
        if kb_ids:
            kb_filter = "(" + " OR ".join([f"kb_id = '{kb_id}'" for kb_id in kb_ids]) + ")"
            if "where" not in sql.lower():
        # Add kb_id filter for ES/OS only (Infinity already has it in table name)
        if doc_engine != "infinity" and kb_ids:
            # Build kb_filter: single KB or multiple KBs with OR
            if len(kb_ids) == 1:
                kb_filter = f"kb_id = '{kb_ids[0]}'"
            else:
                kb_filter = "(" + " OR ".join([f"kb_id = '{kb_id}'" for kb_id in kb_ids]) + ")"

            if "where " not in sql.lower():
                o = sql.lower().split("order by")
                if len(o) > 1:
                    sql = o[0] + f" WHERE {kb_filter} order by " + o[1]
                else:
                    sql += f" WHERE {kb_filter}"
            else:
                sql += f" AND {kb_filter}"
            elif "kb_id =" not in sql.lower() and "kb_id=" not in sql.lower():
                sql = re.sub(r"\bwhere\b ", f"where {kb_filter} and ", sql, flags=re.IGNORECASE)

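A small sketch of the kb_id filter injection on its two interesting cases, mirroring the branch above (sample SQL and kb id invented):

```python
import re

def inject_kb_filter(sql: str, kb_filter: str) -> str:
    # Append WHERE before ORDER BY, or AND onto an existing WHERE
    # that does not already constrain kb_id.
    if "where " not in sql.lower():
        parts = sql.lower().split("order by")
        if len(parts) > 1:
            return parts[0] + f" WHERE {kb_filter} order by " + parts[1]
        return sql + f" WHERE {kb_filter}"
    if "kb_id =" not in sql.lower() and "kb_id=" not in sql.lower():
        return re.sub(r"\bwhere\b ", f"where {kb_filter} and ", sql, flags=re.IGNORECASE)
    return sql

print(inject_kb_filter("select doc_id from t order by price", "kb_id = 'kb9'"))
# select doc_id from t  WHERE kb_id = 'kb9' order by  price
print(inject_kb_filter("select doc_id from t where price > 1", "kb_id = 'kb9'"))
# select doc_id from t where kb_id = 'kb9' and price > 1
```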
logging.debug(f"{question} get SQL(refined): {sql}")
|
||||
tried_times += 1
|
||||
return settings.retriever.sql_retrieval(sql, format="json"), sql
|
||||
logging.debug(f"use_sql: Executing SQL retrieval (attempt {tried_times})")
|
||||
tbl = settings.retriever.sql_retrieval(sql, format="json")
|
||||
if tbl is None:
|
||||
logging.debug("use_sql: SQL retrieval returned None")
|
||||
return None, sql
|
||||
logging.debug(f"use_sql: SQL retrieval completed, got {len(tbl.get('rows', []))} rows")
|
||||
return tbl, sql
|
||||
|
||||
try:
|
||||
tbl, sql = await get_table()
|
||||
logging.debug(f"use_sql: Initial SQL execution SUCCESS. SQL: {sql}")
|
||||
logging.debug(f"use_sql: Retrieved {len(tbl.get('rows', []))} rows, columns: {[c['name'] for c in tbl.get('columns', [])]}")
|
||||
except Exception as e:
|
||||
user_prompt = """
|
||||
logging.warning(f"use_sql: Initial SQL execution FAILED with error: {e}")
|
||||
# Build retry prompt with error information
|
||||
if doc_engine == "infinity":
|
||||
# Build Infinity error retry prompt
|
||||
json_field_names = list(field_map.keys())
|
||||
user_prompt = """
|
||||
Table name: {};
|
||||
JSON fields available in 'chunk_data' column (use these exact names in json_extract_string):
|
||||
{}
|
||||
|
||||
Question: {}
|
||||
Please write the SQL using json_extract_string(chunk_data, '$.field_name') with the field names from the list above. Only SQL, no explanations.
|
||||
|
||||
|
||||
The SQL error you provided last time is as follows:
|
||||
{}
|
||||
|
||||
Please correct the error and write SQL again using json_extract_string(chunk_data, '$.field_name') syntax with the correct field names. Only SQL, no explanations.
|
||||
""".format(table_name, "\n".join([f" - {field}" for field in json_field_names]), question, e)
|
||||
else:
|
||||
# Build ES/OS error retry prompt
|
||||
user_prompt = """
|
||||
Table name: {};
|
||||
Table of database fields are as follows:
|
||||
Table of database fields are as follows (use the field names directly in SQL):
|
||||
{}
|
||||
|
||||
Question are as follows:
|
||||
{}
|
||||
Please write the SQL, only SQL, without any other explanations or text.
|
||||
Please write the SQL using the exact field names above, only SQL, without any other explanations or text.
|
||||
|
||||
|
||||
The SQL error you provided last time is as follows:
|
||||
{}
|
||||
|
||||
Please correct the error and write SQL again, only SQL, without any other explanations or text.
|
||||
""".format(index_name(tenant_id), "\n".join([f"{k}: {v}" for k, v in field_map.items()]), question, e)
|
||||
Please correct the error and write SQL again using the exact field names above, only SQL, without any other explanations or text.
|
||||
""".format(table_name, "\n".join([f"{k} ({v})" for k, v in field_map.items()]), question, e)
|
||||
try:
|
||||
tbl, sql = await get_table()
|
||||
logging.debug(f"use_sql: Retry SQL execution SUCCESS. SQL: {sql}")
|
||||
logging.debug(f"use_sql: Retrieved {len(tbl.get('rows', []))} rows on retry")
|
||||
except Exception:
|
||||
logging.error("use_sql: Retry SQL execution also FAILED, returning None")
|
||||
return
|
||||
|
||||
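The surrounding try/except pair implements a generate, execute, repair-once loop: the engine's error text is fed back into the user prompt and the model gets one more attempt. A generic sketch of the pattern (`generate_sql` and `run_sql` are hypothetical stand-ins for the LLM call and `settings.retriever.sql_retrieval`):

```python
# Generic sketch of the generate -> execute -> repair-once loop used above.
async def sql_with_one_retry(question: str) -> dict | None:
    sql = await generate_sql(question)
    try:
        return await run_sql(sql)
    except Exception as err:
        # Feed the engine's error message back so the model can repair the SQL.
        sql = await generate_sql(f"{question}\nPrevious SQL failed with: {err}")
        try:
            return await run_sql(sql)
        except Exception:
            return None  # caller falls back to vector retrieval
```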
if len(tbl["rows"]) == 0:
|
||||
logging.warning(f"use_sql: No rows returned from SQL query, returning None. SQL: {sql}")
|
||||
return None
|
||||
|
||||
docid_idx = set([ii for ii, c in enumerate(tbl["columns"]) if c["name"] == "doc_id"])
|
||||
doc_name_idx = set([ii for ii, c in enumerate(tbl["columns"]) if c["name"] == "docnm_kwd"])
|
||||
logging.debug(f"use_sql: Proceeding with {len(tbl['rows'])} rows to build answer")
|
||||
|
||||
docid_idx = set([ii for ii, c in enumerate(tbl["columns"]) if c["name"].lower() == "doc_id"])
|
||||
doc_name_idx = set([ii for ii, c in enumerate(tbl["columns"]) if c["name"].lower() in ["docnm_kwd", "docnm"]])
|
||||
|
||||
logging.debug(f"use_sql: All columns: {[(i, c['name']) for i, c in enumerate(tbl['columns'])]}")
|
||||
logging.debug(f"use_sql: docid_idx={docid_idx}, doc_name_idx={doc_name_idx}")
|
||||
|
||||
column_idx = [ii for ii in range(len(tbl["columns"])) if ii not in (docid_idx | doc_name_idx)]
|
||||
|
||||
logging.debug(f"use_sql: column_idx={column_idx}")
|
||||
logging.debug(f"use_sql: field_map={field_map}")
|
||||
|
||||
# Helper function to map column names to display names
|
||||
def map_column_name(col_name):
|
||||
if col_name.lower() == "count(star)":
|
||||
return "COUNT(*)"
|
||||
|
||||
# First, try to extract AS alias from any expression (aggregate functions, json_extract_string, etc.)
|
||||
# Pattern: anything AS alias_name
|
||||
as_match = re.search(r'\s+AS\s+([^\s,)]+)', col_name, re.IGNORECASE)
|
||||
if as_match:
|
||||
alias = as_match.group(1).strip('"\'')
|
||||
|
||||
# Use the alias for display name lookup
|
||||
if alias in field_map:
|
||||
display = field_map[alias]
|
||||
return re.sub(r"(/.*|([^()]+))", "", display)
|
||||
# If alias not in field_map, try to match case-insensitively
|
||||
for field_key, display_value in field_map.items():
|
||||
if field_key.lower() == alias.lower():
|
||||
return re.sub(r"(/.*|([^()]+))", "", display_value)
|
||||
# Return alias as-is if no mapping found
|
||||
return alias
|
||||
|
||||
# Try direct mapping first (for simple column names)
|
||||
if col_name in field_map:
|
||||
display = field_map[col_name]
|
||||
# Clean up any suffix patterns
|
||||
return re.sub(r"(/.*|([^()]+))", "", display)
|
||||
|
||||
# Try case-insensitive match for simple column names
|
||||
col_lower = col_name.lower()
|
||||
for field_key, display_value in field_map.items():
|
||||
if field_key.lower() == col_lower:
|
||||
return re.sub(r"(/.*|([^()]+))", "", display_value)
|
||||
|
||||
# For aggregate expressions or complex expressions without AS alias,
|
||||
# try to replace field names with display names
|
||||
result = col_name
|
||||
for field_name, display_name in field_map.items():
|
||||
# Replace field_name with display_name in the expression
|
||||
result = result.replace(field_name, display_name)
|
||||
|
||||
# Clean up any suffix patterns
|
||||
result = re.sub(r"(/.*|([^()]+))", "", result)
|
||||
return result
|
||||
|
||||
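To make the alias handling in `map_column_name` concrete, here is the AS-extraction step in isolation (the column expression and field map are invented; the `（…）` suffix stripping matches the regex above):

```python
import re

field_map = {"price_flt": "Price（CNY）"}  # invented; （…） suffix is display-only

col_name = "json_extract_string(chunk_data, '$.price_flt') AS price_flt"
as_match = re.search(r'\s+AS\s+([^\s,)]+)', col_name, re.IGNORECASE)
alias = as_match.group(1).strip('"\'')            # -> "price_flt"
display = field_map[alias]                         # -> "Price（CNY）"
print(re.sub(r"(/.*|（[^（）]+）)", "", display))   # -> "Price"
```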
    # compose Markdown table
    columns = (
        "|" + "|".join(
            [re.sub(r"(/.*|（[^（）]+）)", "", field_map.get(tbl["columns"][i]["name"], tbl["columns"][i]["name"])) for i in column_idx]) + (
            "|Source|" if docid_idx and docid_idx else "|")
            [map_column_name(tbl["columns"][i]["name"]) for i in column_idx]) + (
            "|Source|" if docid_idx and doc_name_idx else "|")
    )

    line = "|" + "|".join(["------" for _ in range(len(column_idx))]) + ("|------|" if docid_idx and docid_idx else "")

    rows = ["|" + "|".join([remove_redundant_spaces(str(r[i])) for i in column_idx]).replace("None", " ") + "|" for r in tbl["rows"]]
    rows = [r for r in rows if re.sub(r"[ |]+", "", r)]
    # Build rows ensuring column names match values - create a dict for each row
    # keyed by column name to handle any SQL column order
    rows = []
    for row_idx, r in enumerate(tbl["rows"]):
        row_dict = {tbl["columns"][i]["name"]: r[i] for i in range(len(tbl["columns"])) if i < len(r)}
        if row_idx == 0:
            logging.debug(f"use_sql: First row data: {row_dict}")
        row_values = []
        for col_idx in column_idx:
            col_name = tbl["columns"][col_idx]["name"]
            value = row_dict.get(col_name, " ")
            row_values.append(remove_redundant_spaces(str(value)).replace("None", " "))
        # Add Source column with citation marker if Source column exists
        if docid_idx and doc_name_idx:
            row_values.append(f" ##{row_idx}$$")
        row_str = "|" + "|".join(row_values) + "|"
        if re.sub(r"[ |]+", "", row_str):
            rows.append(row_str)
    if quota:
        rows = "\n".join([r + f" ##{ii}$$ |" for ii, r in enumerate(rows)])
        rows = "\n".join(rows)
    else:
        rows = "\n".join([r + f" ##{ii}$$ |" for ii, r in enumerate(rows)])
        rows = "\n".join(rows)
    rows = re.sub(r"T[0-9]{2}:[0-9]{2}:[0-9]{2}(\.[0-9]+Z)?\|", "|", rows)

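A compact sketch of the resulting Markdown layout with the `##N$$` citation markers (sample rows invented):

```python
# Invented two-column result, matching the header / separator / rows layout
# composed above, with a per-row citation marker in the Source column.
headers = ["Product", "Price"]
data = [["Widget", "9.9"], ["Gadget", "19.9"]]

columns = "|" + "|".join(headers) + "|Source|"
line = "|" + "|".join(["------" for _ in headers]) + "|------|"
rows = "\n".join(
    "|" + "|".join(vals) + f"| ##{i}$$|" for i, vals in enumerate(data)
)
print("\n".join([columns, line, rows]))
# |Product|Price|Source|
# |------|------|------|
# |Widget|9.9| ##0$$|
# |Gadget|19.9| ##1$$|
```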
    if not docid_idx or not doc_name_idx:
        logging.warning("SQL missing field: " + sql)
        logging.warning(f"use_sql: SQL missing required doc_id or docnm_kwd field. docid_idx={docid_idx}, doc_name_idx={doc_name_idx}. SQL: {sql}")
        # For aggregate queries (COUNT, SUM, AVG, MAX, MIN, DISTINCT), fetch doc_id, docnm_kwd separately
        # to provide source chunks, but keep the original table format answer
        if re.search(r"(count|sum|avg|max|min|distinct)\s*\(", sql.lower()):
            # Keep original table format as answer
            answer = "\n".join([columns, line, rows])

            # Now fetch doc_id, docnm_kwd to provide source chunks
            # Extract WHERE clause from the original SQL
            where_match = re.search(r"\bwhere\b(.+?)(?:\bgroup by\b|\border by\b|\blimit\b|$)", sql, re.IGNORECASE)
            if where_match:
                where_clause = where_match.group(1).strip()
                # Build a query to get doc_id and docnm_kwd with the same WHERE clause
                chunks_sql = f"select doc_id, docnm_kwd from {table_name} where {where_clause}"
                # Add LIMIT to avoid fetching too many chunks
                if "limit" not in chunks_sql.lower():
                    chunks_sql += " limit 20"
                logging.debug(f"use_sql: Fetching chunks with SQL: {chunks_sql}")
                try:
                    chunks_tbl = settings.retriever.sql_retrieval(chunks_sql, format="json")
                    if chunks_tbl.get("rows") and len(chunks_tbl["rows"]) > 0:
                        # Build chunks reference - use case-insensitive matching
                        chunks_did_idx = next((i for i, c in enumerate(chunks_tbl["columns"]) if c["name"].lower() == "doc_id"), None)
                        chunks_dn_idx = next((i for i, c in enumerate(chunks_tbl["columns"]) if c["name"].lower() in ["docnm_kwd", "docnm"]), None)
                        if chunks_did_idx is not None and chunks_dn_idx is not None:
                            chunks = [{"doc_id": r[chunks_did_idx], "docnm_kwd": r[chunks_dn_idx]} for r in chunks_tbl["rows"]]
                            # Build doc_aggs
                            doc_aggs = {}
                            for r in chunks_tbl["rows"]:
                                doc_id = r[chunks_did_idx]
                                doc_name = r[chunks_dn_idx]
                                if doc_id not in doc_aggs:
                                    doc_aggs[doc_id] = {"doc_name": doc_name, "count": 0}
                                doc_aggs[doc_id]["count"] += 1
                            doc_aggs_list = [{"doc_id": did, "doc_name": d["doc_name"], "count": d["count"]} for did, d in doc_aggs.items()]
                            logging.debug(f"use_sql: Returning aggregate answer with {len(chunks)} chunks from {len(doc_aggs)} documents")
                            return {"answer": answer, "reference": {"chunks": chunks, "doc_aggs": doc_aggs_list}, "prompt": sys_prompt}
                except Exception as e:
                    logging.warning(f"use_sql: Failed to fetch chunks: {e}")
            # Fallback: return answer without chunks
            return {"answer": answer, "reference": {"chunks": [], "doc_aggs": []}, "prompt": sys_prompt}
        # Fallback to table format for other cases
        return {"answer": "\n".join([columns, line, rows]), "reference": {"chunks": [], "doc_aggs": []}, "prompt": sys_prompt}

    docid_idx = list(docid_idx)[0]
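The WHERE-clause recovery used for aggregate queries can be checked in isolation (sample SQL invented):

```python
import re

# Invented aggregate query; the source-chunk query reuses its WHERE clause.
sql = "select count(*) from ragflow_t1 where kb_id = 'kb9' group by brand_tks"
m = re.search(r"\bwhere\b(.+?)(?:\bgroup by\b|\border by\b|\blimit\b|$)", sql, re.IGNORECASE)
where_clause = m.group(1).strip()
print(where_clause)  # -> kb_id = 'kb9'
chunks_sql = f"select doc_id, docnm_kwd from ragflow_t1 where {where_clause} limit 20"
```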
@@ -690,7 +892,8 @@ Please write the SQL, only SQL, without any other explanations or text.
        if r[docid_idx] not in doc_aggs:
            doc_aggs[r[docid_idx]] = {"doc_name": r[doc_name_idx], "count": 0}
        doc_aggs[r[docid_idx]]["count"] += 1
    return {

    result = {
        "answer": "\n".join([columns, line, rows]),
        "reference": {
            "chunks": [{"doc_id": r[docid_idx], "docnm_kwd": r[doc_name_idx]} for r in tbl["rows"]],
@@ -698,6 +901,8 @@ Please write the SQL, only SQL, without any other explanations or text.
        },
        "prompt": sys_prompt,
    }
    logging.debug(f"use_sql: Returning answer with {len(result['reference']['chunks'])} chunks from {len(doc_aggs)} documents")
    return result

def clean_tts_text(text: str) -> str:
    if not text:

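For reference, the shape of the payload `use_sql` hands back to `async_chat` (all values below are invented):

```python
# Invented example of the dict returned on the non-aggregate path above.
result = {
    "answer": "|Product|Price|Source|\n|------|------|------|\n|Widget|9.9| ##0$$|",
    "reference": {
        "chunks": [{"doc_id": "d1", "docnm_kwd": "products.xlsx"}],
        "doc_aggs": [{"doc_id": "d1", "doc_name": "products.xlsx", "count": 1}],
    },
    "prompt": "...",  # the sys_prompt used, kept for traceability
}
```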
@@ -1279,7 +1279,7 @@ def doc_upload_and_parse(conversation_id, file_objs, user_id):
        for b in range(0, len(cks), es_bulk_size):
            if try_create_idx:
                if not settings.docStoreConn.index_exist(idxnm, kb_id):
                    settings.docStoreConn.create_idx(idxnm, kb_id, len(vectors[0]))
                    settings.docStoreConn.create_idx(idxnm, kb_id, len(vectors[0]), kb.parser_id)
                try_create_idx = False
            settings.docStoreConn.insert(cks[b:b + es_bulk_size], idxnm, kb_id)

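This hunk threads the knowledge base's parser type into index creation, which is the hook that lets a store like Infinity lay out a table-parser index. A hedged sketch of what a connector could do with it (the method body is hypothetical; only the extra argument comes from this diff):

```python
# Hypothetical connector-side sketch; only the parser_id parameter is from
# this diff. A store could branch on the KB's parser type at creation time.
def create_idx(self, index_name: str, kb_id: str, vector_size: int, parser_id: str | None = None):
    if parser_id == "table":
        # e.g. Infinity: include a JSON 'chunk_data' column for parsed table rows
        ...
    # otherwise create the default chunk index/table
```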