Add dataset with table parser type for Infinity and answer question in chat using SQL (#12541)

### What problem does this PR solve?

1) Create  dataset using table parser for infinity
2) Answer questions in chat using SQL

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
This commit is contained in:
qinling0210
2026-01-19 19:35:14 +08:00
committed by GitHub
parent 05da2a5872
commit b40d639fdb
19 changed files with 1003 additions and 101 deletions

View File

@ -53,7 +53,8 @@ RUN --mount=type=cache,id=ragflow_apt,target=/var/cache/apt,sharing=locked \
apt install -y ghostscript && \
apt install -y pandoc && \
apt install -y texlive && \
apt install -y fonts-freefont-ttf fonts-noto-cjk
apt install -y fonts-freefont-ttf fonts-noto-cjk && \
apt install -y postgresql-client
# Install uv
RUN --mount=type=bind,from=infiniflow/ragflow_deps:latest,source=/,target=/deps \

View File

@ -25,6 +25,7 @@ from api.utils.api_utils import get_data_error_result, get_json_result, get_requ
from common.misc_utils import get_uuid
from common.constants import RetCode
from api.apps import login_required, current_user
import logging
@manager.route('/set', methods=['POST']) # noqa: F821
@ -69,6 +70,19 @@ async def set_dialog():
meta_data_filter = req.get("meta_data_filter", {})
prompt_config = req["prompt_config"]
# Set default parameters for datasets with knowledge retrieval
# All datasets with {knowledge} in system prompt need "knowledge" parameter to enable retrieval
kb_ids = req.get("kb_ids", [])
parameters = prompt_config.get("parameters")
logging.debug(f"set_dialog: kb_ids={kb_ids}, parameters={parameters}, is_create={not is_create}")
# Check if parameters is missing, None, or empty list
if kb_ids and not parameters:
# Check if system prompt uses {knowledge} placeholder
if "{knowledge}" in prompt_config.get("system", ""):
# Set default parameters for any dataset with knowledge placeholder
prompt_config["parameters"] = [{"key": "knowledge", "optional": False}]
logging.debug(f"Set default parameters for datasets with knowledge placeholder: {kb_ids}")
if not is_create:
# only for chat updating
if not req.get("kb_ids", []) and not prompt_config.get("tavily_api_key") and "{knowledge}" in prompt_config.get("system", ""):

View File

@ -295,12 +295,19 @@ async def rm():
File.name == kbs[0].name,
]
)
# Delete the table BEFORE deleting the database record
for kb in kbs:
try:
settings.docStoreConn.delete({"kb_id": kb.id}, search.index_name(kb.tenant_id), kb.id)
settings.docStoreConn.delete_idx(search.index_name(kb.tenant_id), kb.id)
logging.info(f"Dropped index for dataset {kb.id}")
except Exception as e:
logging.error(f"Failed to drop index for dataset {kb.id}: {e}")
if not KnowledgebaseService.delete_by_id(req["kb_id"]):
return get_data_error_result(
message="Database error (Knowledgebase removal)!")
for kb in kbs:
settings.docStoreConn.delete({"kb_id": kb.id}, search.index_name(kb.tenant_id), kb.id)
settings.docStoreConn.delete_idx(search.index_name(kb.tenant_id), kb.id)
if hasattr(settings.STORAGE_IMPL, 'remove_bucket'):
settings.STORAGE_IMPL.remove_bucket(kb.id)
return get_json_result(data=True)

View File

@ -233,6 +233,15 @@ async def delete(tenant_id):
File2DocumentService.delete_by_document_id(doc.id)
FileService.filter_delete(
[File.source_type == FileSource.KNOWLEDGEBASE, File.type == "folder", File.name == kb.name])
# Drop index for this dataset
try:
from rag.nlp import search
idxnm = search.index_name(kb.tenant_id)
settings.docStoreConn.delete_idx(idxnm, kb_id)
except Exception as e:
logging.warning(f"Failed to drop index for dataset {kb_id}: {e}")
if not KnowledgebaseService.delete_by_id(kb_id):
errors.append(f"Delete dataset error for {kb_id}")
continue

View File

@ -37,7 +37,6 @@ from api.db.services.tenant_llm_service import TenantLLMService
from common.time_utils import current_timestamp, datetime_format
from graphrag.general.mind_map_extractor import MindMapExtractor
from rag.advanced_rag import DeepResearcher
from rag.app.resume import forbidden_select_fields4resume
from rag.app.tag import label_question
from rag.nlp.search import index_name
from rag.prompts.generator import chunks_format, citation_prompt, cross_languages, full_question, kb_prompt, keyword_extraction, message_fit_in, \
@ -274,6 +273,7 @@ def repair_bad_citation_formats(answer: str, kbinfos: dict, idx: set):
async def async_chat(dialog, messages, stream=True, **kwargs):
logging.debug("Begin async_chat")
assert messages[-1]["role"] == "user", "The last content of this conversation is not from user."
if not dialog.kb_ids and not dialog.prompt_config.get("tavily_api_key"):
async for ans in async_chat_solo(dialog, messages, stream):
@ -323,13 +323,20 @@ async def async_chat(dialog, messages, stream=True, **kwargs):
prompt_config = dialog.prompt_config
field_map = KnowledgebaseService.get_field_map(dialog.kb_ids)
logging.debug(f"field_map retrieved: {field_map}")
# try to use sql if field mapping is good to go
if field_map:
logging.debug("Use SQL to retrieval:{}".format(questions[-1]))
ans = await use_sql(questions[-1], field_map, dialog.tenant_id, chat_mdl, prompt_config.get("quote", True), dialog.kb_ids)
if ans:
# For aggregate queries (COUNT, SUM, etc.), chunks may be empty but answer is still valid
if ans and (ans.get("reference", {}).get("chunks") or ans.get("answer")):
yield ans
return
else:
logging.debug("SQL failed or returned no results, falling back to vector search")
param_keys = [p["key"] for p in prompt_config.get("parameters", [])]
logging.debug(f"attachments={attachments}, param_keys={param_keys}, embd_mdl={embd_mdl}")
for p in prompt_config["parameters"]:
if p["key"] == "knowledge":
@ -366,7 +373,8 @@ async def async_chat(dialog, messages, stream=True, **kwargs):
kbinfos = {"total": 0, "chunks": [], "doc_aggs": []}
knowledges = []
if attachments is not None and "knowledge" in [p["key"] for p in prompt_config["parameters"]]:
if attachments is not None and "knowledge" in param_keys:
logging.debug("Proceeding with retrieval")
tenant_ids = list(set([kb.tenant_id for kb in kbs]))
knowledges = []
if prompt_config.get("reasoning", False):
@ -575,112 +583,306 @@ async def async_chat(dialog, messages, stream=True, **kwargs):
async def use_sql(question, field_map, tenant_id, chat_mdl, quota=True, kb_ids=None):
sys_prompt = """
You are a Database Administrator. You need to check the fields of the following tables based on the user's list of questions and write the SQL corresponding to the last question.
Ensure that:
1. Field names should not start with a digit. If any field name starts with a digit, use double quotes around it.
2. Write only the SQL, no explanations or additional text.
"""
user_prompt = """
Table name: {};
Table of database fields are as follows:
{}
logging.debug(f"use_sql: Question: {question}")
Question are as follows:
# Determine which document engine we're using
doc_engine = "infinity" if settings.DOC_ENGINE_INFINITY else "es"
# Construct the full table name
# For Elasticsearch: ragflow_{tenant_id} (kb_id is in WHERE clause)
# For Infinity: ragflow_{tenant_id}_{kb_id} (each KB has its own table)
base_table = index_name(tenant_id)
if doc_engine == "infinity" and kb_ids and len(kb_ids) == 1:
# Infinity: append kb_id to table name
table_name = f"{base_table}_{kb_ids[0]}"
logging.debug(f"use_sql: Using Infinity table name: {table_name}")
else:
# Elasticsearch/OpenSearch: use base index name
table_name = base_table
logging.debug(f"use_sql: Using ES/OS table name: {table_name}")
# Generate engine-specific SQL prompts
if doc_engine == "infinity":
# Build Infinity prompts with JSON extraction context
json_field_names = list(field_map.keys())
sys_prompt = """You are a Database Administrator. Write SQL for a table with JSON 'chunk_data' column.
JSON Extraction: json_extract_string(chunk_data, '$.FieldName')
Numeric Cast: CAST(json_extract_string(chunk_data, '$.FieldName') AS INTEGER/FLOAT)
NULL Check: json_extract_isnull(chunk_data, '$.FieldName') == false
RULES:
1. Use EXACT field names (case-sensitive) from the list below
2. For SELECT: include doc_id, docnm, and json_extract_string() for requested fields
3. For COUNT: use COUNT(*) or COUNT(DISTINCT json_extract_string(...))
4. Add AS alias for extracted field names
5. DO NOT select 'content' field
6. Only add NULL check (json_extract_isnull() == false) in WHERE clause when:
- Question asks to "show me" or "display" specific columns
- Question mentions "not null" or "excluding null"
- Add NULL check for count specific column
- DO NOT add NULL check for COUNT(*) queries (COUNT(*) counts all rows including nulls)
7. Output ONLY the SQL, no explanations"""
user_prompt = """Table: {}
Fields (EXACT case): {}
{}
Please write the SQL, only SQL, without any other explanations or text.
""".format(index_name(tenant_id), "\n".join([f"{k}: {v}" for k, v in field_map.items()]), question)
Question: {}
Write SQL using json_extract_string() with exact field names. Include doc_id, docnm for data queries. Only SQL.""".format(
table_name,
", ".join(json_field_names),
"\n".join([f" - {field}" for field in json_field_names]),
question
)
else:
# Build ES/OS prompts with direct field access
sys_prompt = """You are a Database Administrator. Write SQL queries.
RULES:
1. Use EXACT field names from the schema below (e.g., product_tks, not product)
2. Quote field names starting with digit: "123_field"
3. Add IS NOT NULL in WHERE clause when:
- Question asks to "show me" or "display" specific columns
4. Include doc_id/docnm in non-aggregate statement
5. Output ONLY the SQL, no explanations"""
user_prompt = """Table: {}
Available fields:
{}
Question: {}
Write SQL using exact field names above. Include doc_id, docnm_kwd for data queries. Only SQL.""".format(
table_name,
"\n".join([f" - {k} ({v})" for k, v in field_map.items()]),
question
)
tried_times = 0
async def get_table():
nonlocal sys_prompt, user_prompt, question, tried_times
sql = await chat_mdl.async_chat(sys_prompt, [{"role": "user", "content": user_prompt}], {"temperature": 0.06})
sql = re.sub(r"^.*</think>", "", sql, flags=re.DOTALL)
logging.debug(f"{question} ==> {user_prompt} get SQL: {sql}")
sql = re.sub(r"[\r\n]+", " ", sql.lower())
sql = re.sub(r".*select ", "select ", sql.lower())
sql = re.sub(r" +", " ", sql)
sql = re.sub(r"([;]|```).*", "", sql)
sql = re.sub(r"&", "and", sql)
if sql[: len("select ")] != "select ":
return None, None
if not re.search(r"((sum|avg|max|min)\(|group by )", sql.lower()):
if sql[: len("select *")] != "select *":
sql = "select doc_id,docnm_kwd," + sql[6:]
else:
flds = []
for k in field_map.keys():
if k in forbidden_select_fields4resume:
continue
if len(flds) > 11:
break
flds.append(k)
sql = "select doc_id,docnm_kwd," + ",".join(flds) + sql[8:]
logging.debug(f"use_sql: Raw SQL from LLM: {repr(sql[:500])}")
# Remove think blocks if present (format: </think>...)
sql = re.sub(r"</think>\n.*?\n\s*", "", sql, flags=re.DOTALL)
sql = re.sub(r"思考\n.*?\n", "", sql, flags=re.DOTALL)
# Remove markdown code blocks (```sql ... ```)
sql = re.sub(r"```(?:sql)?\s*", "", sql, flags=re.IGNORECASE)
sql = re.sub(r"```\s*$", "", sql, flags=re.IGNORECASE)
# Remove trailing semicolon that ES SQL parser doesn't like
sql = sql.rstrip().rstrip(';').strip()
if kb_ids:
# Add kb_id filter for ES/OS only (Infinity already has it in table name)
if doc_engine != "infinity" and kb_ids:
# Build kb_filter: single KB or multiple KBs with OR
if len(kb_ids) == 1:
kb_filter = f"kb_id = '{kb_ids[0]}'"
else:
kb_filter = "(" + " OR ".join([f"kb_id = '{kb_id}'" for kb_id in kb_ids]) + ")"
if "where " not in sql.lower():
o = sql.lower().split("order by")
if len(o) > 1:
sql = o[0] + f" WHERE {kb_filter} order by " + o[1]
else:
sql += f" WHERE {kb_filter}"
else:
sql += f" AND {kb_filter}"
elif "kb_id =" not in sql.lower() and "kb_id=" not in sql.lower():
sql = re.sub(r"\bwhere\b ", f"where {kb_filter} and ", sql, flags=re.IGNORECASE)
logging.debug(f"{question} get SQL(refined): {sql}")
tried_times += 1
return settings.retriever.sql_retrieval(sql, format="json"), sql
logging.debug(f"use_sql: Executing SQL retrieval (attempt {tried_times})")
tbl = settings.retriever.sql_retrieval(sql, format="json")
if tbl is None:
logging.debug("use_sql: SQL retrieval returned None")
return None, sql
logging.debug(f"use_sql: SQL retrieval completed, got {len(tbl.get('rows', []))} rows")
return tbl, sql
try:
tbl, sql = await get_table()
logging.debug(f"use_sql: Initial SQL execution SUCCESS. SQL: {sql}")
logging.debug(f"use_sql: Retrieved {len(tbl.get('rows', []))} rows, columns: {[c['name'] for c in tbl.get('columns', [])]}")
except Exception as e:
logging.warning(f"use_sql: Initial SQL execution FAILED with error: {e}")
# Build retry prompt with error information
if doc_engine == "infinity":
# Build Infinity error retry prompt
json_field_names = list(field_map.keys())
user_prompt = """
Table name: {};
Table of database fields are as follows:
JSON fields available in 'chunk_data' column (use these exact names in json_extract_string):
{}
Question are as follows:
{}
Please write the SQL, only SQL, without any other explanations or text.
Question: {}
Please write the SQL using json_extract_string(chunk_data, '$.field_name') with the field names from the list above. Only SQL, no explanations.
The SQL error you provided last time is as follows:
{}
Please correct the error and write SQL again, only SQL, without any other explanations or text.
""".format(index_name(tenant_id), "\n".join([f"{k}: {v}" for k, v in field_map.items()]), question, e)
Please correct the error and write SQL again using json_extract_string(chunk_data, '$.field_name') syntax with the correct field names. Only SQL, no explanations.
""".format(table_name, "\n".join([f" - {field}" for field in json_field_names]), question, e)
else:
# Build ES/OS error retry prompt
user_prompt = """
Table name: {};
Table of database fields are as follows (use the field names directly in SQL):
{}
Question are as follows:
{}
Please write the SQL using the exact field names above, only SQL, without any other explanations or text.
The SQL error you provided last time is as follows:
{}
Please correct the error and write SQL again using the exact field names above, only SQL, without any other explanations or text.
""".format(table_name, "\n".join([f"{k} ({v})" for k, v in field_map.items()]), question, e)
try:
tbl, sql = await get_table()
logging.debug(f"use_sql: Retry SQL execution SUCCESS. SQL: {sql}")
logging.debug(f"use_sql: Retrieved {len(tbl.get('rows', []))} rows on retry")
except Exception:
logging.error("use_sql: Retry SQL execution also FAILED, returning None")
return
if len(tbl["rows"]) == 0:
logging.warning(f"use_sql: No rows returned from SQL query, returning None. SQL: {sql}")
return None
docid_idx = set([ii for ii, c in enumerate(tbl["columns"]) if c["name"] == "doc_id"])
doc_name_idx = set([ii for ii, c in enumerate(tbl["columns"]) if c["name"] == "docnm_kwd"])
logging.debug(f"use_sql: Proceeding with {len(tbl['rows'])} rows to build answer")
docid_idx = set([ii for ii, c in enumerate(tbl["columns"]) if c["name"].lower() == "doc_id"])
doc_name_idx = set([ii for ii, c in enumerate(tbl["columns"]) if c["name"].lower() in ["docnm_kwd", "docnm"]])
logging.debug(f"use_sql: All columns: {[(i, c['name']) for i, c in enumerate(tbl['columns'])]}")
logging.debug(f"use_sql: docid_idx={docid_idx}, doc_name_idx={doc_name_idx}")
column_idx = [ii for ii in range(len(tbl["columns"])) if ii not in (docid_idx | doc_name_idx)]
logging.debug(f"use_sql: column_idx={column_idx}")
logging.debug(f"use_sql: field_map={field_map}")
# Helper function to map column names to display names
def map_column_name(col_name):
if col_name.lower() == "count(star)":
return "COUNT(*)"
# First, try to extract AS alias from any expression (aggregate functions, json_extract_string, etc.)
# Pattern: anything AS alias_name
as_match = re.search(r'\s+AS\s+([^\s,)]+)', col_name, re.IGNORECASE)
if as_match:
alias = as_match.group(1).strip('"\'')
# Use the alias for display name lookup
if alias in field_map:
display = field_map[alias]
return re.sub(r"(/.*|[^]+)", "", display)
# If alias not in field_map, try to match case-insensitively
for field_key, display_value in field_map.items():
if field_key.lower() == alias.lower():
return re.sub(r"(/.*|[^]+)", "", display_value)
# Return alias as-is if no mapping found
return alias
# Try direct mapping first (for simple column names)
if col_name in field_map:
display = field_map[col_name]
# Clean up any suffix patterns
return re.sub(r"(/.*|[^]+)", "", display)
# Try case-insensitive match for simple column names
col_lower = col_name.lower()
for field_key, display_value in field_map.items():
if field_key.lower() == col_lower:
return re.sub(r"(/.*|[^]+)", "", display_value)
# For aggregate expressions or complex expressions without AS alias,
# try to replace field names with display names
result = col_name
for field_name, display_name in field_map.items():
# Replace field_name with display_name in the expression
result = result.replace(field_name, display_name)
# Clean up any suffix patterns
result = re.sub(r"(/.*|[^]+)", "", result)
return result
# compose Markdown table
columns = (
"|" + "|".join(
[re.sub(r"(/.*|[^]+)", "", field_map.get(tbl["columns"][i]["name"], tbl["columns"][i]["name"])) for i in column_idx]) + (
"|Source|" if docid_idx and docid_idx else "|")
[map_column_name(tbl["columns"][i]["name"]) for i in column_idx]) + (
"|Source|" if docid_idx and doc_name_idx else "|")
)
line = "|" + "|".join(["------" for _ in range(len(column_idx))]) + ("|------|" if docid_idx and docid_idx else "")
rows = ["|" + "|".join([remove_redundant_spaces(str(r[i])) for i in column_idx]).replace("None", " ") + "|" for r in tbl["rows"]]
rows = [r for r in rows if re.sub(r"[ |]+", "", r)]
# Build rows ensuring column names match values - create a dict for each row
# keyed by column name to handle any SQL column order
rows = []
for row_idx, r in enumerate(tbl["rows"]):
row_dict = {tbl["columns"][i]["name"]: r[i] for i in range(len(tbl["columns"])) if i < len(r)}
if row_idx == 0:
logging.debug(f"use_sql: First row data: {row_dict}")
row_values = []
for col_idx in column_idx:
col_name = tbl["columns"][col_idx]["name"]
value = row_dict.get(col_name, " ")
row_values.append(remove_redundant_spaces(str(value)).replace("None", " "))
# Add Source column with citation marker if Source column exists
if docid_idx and doc_name_idx:
row_values.append(f" ##{row_idx}$$")
row_str = "|" + "|".join(row_values) + "|"
if re.sub(r"[ |]+", "", row_str):
rows.append(row_str)
if quota:
rows = "\n".join([r + f" ##{ii}$$ |" for ii, r in enumerate(rows)])
rows = "\n".join(rows)
else:
rows = "\n".join([r + f" ##{ii}$$ |" for ii, r in enumerate(rows)])
rows = "\n".join(rows)
rows = re.sub(r"T[0-9]{2}:[0-9]{2}:[0-9]{2}(\.[0-9]+Z)?\|", "|", rows)
if not docid_idx or not doc_name_idx:
logging.warning("SQL missing field: " + sql)
logging.warning(f"use_sql: SQL missing required doc_id or docnm_kwd field. docid_idx={docid_idx}, doc_name_idx={doc_name_idx}. SQL: {sql}")
# For aggregate queries (COUNT, SUM, AVG, MAX, MIN, DISTINCT), fetch doc_id, docnm_kwd separately
# to provide source chunks, but keep the original table format answer
if re.search(r"(count|sum|avg|max|min|distinct)\s*\(", sql.lower()):
# Keep original table format as answer
answer = "\n".join([columns, line, rows])
# Now fetch doc_id, docnm_kwd to provide source chunks
# Extract WHERE clause from the original SQL
where_match = re.search(r"\bwhere\b(.+?)(?:\bgroup by\b|\border by\b|\blimit\b|$)", sql, re.IGNORECASE)
if where_match:
where_clause = where_match.group(1).strip()
# Build a query to get doc_id and docnm_kwd with the same WHERE clause
chunks_sql = f"select doc_id, docnm_kwd from {table_name} where {where_clause}"
# Add LIMIT to avoid fetching too many chunks
if "limit" not in chunks_sql.lower():
chunks_sql += " limit 20"
logging.debug(f"use_sql: Fetching chunks with SQL: {chunks_sql}")
try:
chunks_tbl = settings.retriever.sql_retrieval(chunks_sql, format="json")
if chunks_tbl.get("rows") and len(chunks_tbl["rows"]) > 0:
# Build chunks reference - use case-insensitive matching
chunks_did_idx = next((i for i, c in enumerate(chunks_tbl["columns"]) if c["name"].lower() == "doc_id"), None)
chunks_dn_idx = next((i for i, c in enumerate(chunks_tbl["columns"]) if c["name"].lower() in ["docnm_kwd", "docnm"]), None)
if chunks_did_idx is not None and chunks_dn_idx is not None:
chunks = [{"doc_id": r[chunks_did_idx], "docnm_kwd": r[chunks_dn_idx]} for r in chunks_tbl["rows"]]
# Build doc_aggs
doc_aggs = {}
for r in chunks_tbl["rows"]:
doc_id = r[chunks_did_idx]
doc_name = r[chunks_dn_idx]
if doc_id not in doc_aggs:
doc_aggs[doc_id] = {"doc_name": doc_name, "count": 0}
doc_aggs[doc_id]["count"] += 1
doc_aggs_list = [{"doc_id": did, "doc_name": d["doc_name"], "count": d["count"]} for did, d in doc_aggs.items()]
logging.debug(f"use_sql: Returning aggregate answer with {len(chunks)} chunks from {len(doc_aggs)} documents")
return {"answer": answer, "reference": {"chunks": chunks, "doc_aggs": doc_aggs_list}, "prompt": sys_prompt}
except Exception as e:
logging.warning(f"use_sql: Failed to fetch chunks: {e}")
# Fallback: return answer without chunks
return {"answer": answer, "reference": {"chunks": [], "doc_aggs": []}, "prompt": sys_prompt}
# Fallback to table format for other cases
return {"answer": "\n".join([columns, line, rows]), "reference": {"chunks": [], "doc_aggs": []}, "prompt": sys_prompt}
docid_idx = list(docid_idx)[0]
@ -690,7 +892,8 @@ Please write the SQL, only SQL, without any other explanations or text.
if r[docid_idx] not in doc_aggs:
doc_aggs[r[docid_idx]] = {"doc_name": r[doc_name_idx], "count": 0}
doc_aggs[r[docid_idx]]["count"] += 1
return {
result = {
"answer": "\n".join([columns, line, rows]),
"reference": {
"chunks": [{"doc_id": r[docid_idx], "docnm_kwd": r[doc_name_idx]} for r in tbl["rows"]],
@ -698,6 +901,8 @@ Please write the SQL, only SQL, without any other explanations or text.
},
"prompt": sys_prompt,
}
logging.debug(f"use_sql: Returning answer with {len(result['reference']['chunks'])} chunks from {len(doc_aggs)} documents")
return result
def clean_tts_text(text: str) -> str:
if not text:

View File

@ -1279,7 +1279,7 @@ def doc_upload_and_parse(conversation_id, file_objs, user_id):
for b in range(0, len(cks), es_bulk_size):
if try_create_idx:
if not settings.docStoreConn.index_exist(idxnm, kb_id):
settings.docStoreConn.create_idx(idxnm, kb_id, len(vectors[0]))
settings.docStoreConn.create_idx(idxnm, kb_id, len(vectors[0]), kb.parser_id)
try_create_idx = False
settings.docStoreConn.insert(cks[b:b + es_bulk_size], idxnm, kb_id)

View File

@ -164,7 +164,7 @@ class DocStoreConnection(ABC):
"""
@abstractmethod
def create_idx(self, index_name: str, dataset_id: str, vector_size: int):
def create_idx(self, index_name: str, dataset_id: str, vector_size: int, parser_id: str = None):
"""
Create an index with given name
"""

View File

@ -123,7 +123,8 @@ class ESConnectionBase(DocStoreConnection):
Table operations
"""
def create_idx(self, index_name: str, dataset_id: str, vector_size: int):
def create_idx(self, index_name: str, dataset_id: str, vector_size: int, parser_id: str = None):
# parser_id is used by Infinity but not needed for ES (kept for interface compatibility)
if self.index_exist(index_name, dataset_id):
return True
try:

View File

@ -228,15 +228,26 @@ class InfinityConnectionBase(DocStoreConnection):
Table operations
"""
def create_idx(self, index_name: str, dataset_id: str, vector_size: int):
def create_idx(self, index_name: str, dataset_id: str, vector_size: int, parser_id: str = None):
table_name = f"{index_name}_{dataset_id}"
self.logger.debug(f"CREATE_IDX: Creating table {table_name}, parser_id: {parser_id}")
inf_conn = self.connPool.get_conn()
inf_db = inf_conn.create_database(self.dbName, ConflictType.Ignore)
# Use configured schema
fp_mapping = os.path.join(get_project_base_directory(), "conf", self.mapping_file_name)
if not os.path.exists(fp_mapping):
raise Exception(f"Mapping file not found at {fp_mapping}")
schema = json.load(open(fp_mapping))
if parser_id is not None:
from common.constants import ParserType
if parser_id == ParserType.TABLE.value:
# Table parser: add chunk_data JSON column to store table-specific fields
schema["chunk_data"] = {"type": "json", "default": "{}"}
self.logger.info("Added chunk_data column for TABLE parser")
vector_name = f"q_{vector_size}_vec"
schema[vector_name] = {"type": f"vector,{vector_size},float"}
inf_table = inf_db.create_table(
@ -453,4 +464,198 @@ class InfinityConnectionBase(DocStoreConnection):
"""
def sql(self, sql: str, fetch_size: int, format: str):
raise NotImplementedError("Not implemented")
"""
Execute SQL query on Infinity database via psql command.
Transform text-to-sql for Infinity's SQL syntax.
"""
import subprocess
try:
self.logger.debug(f"InfinityConnection.sql get sql: {sql}")
# Clean up SQL
sql = re.sub(r"[ `]+", " ", sql)
sql = sql.replace("%", "")
# Transform SELECT field aliases to actual stored field names
# Build field mapping from infinity_mapping.json comment field
field_mapping = {}
# Also build reverse mapping for column names in result
reverse_mapping = {}
fp_mapping = os.path.join(get_project_base_directory(), "conf", self.mapping_file_name)
if os.path.exists(fp_mapping):
schema = json.load(open(fp_mapping))
for field_name, field_info in schema.items():
if "comment" in field_info:
# Parse comma-separated aliases from comment
# e.g., "docnm_kwd, title_tks, title_sm_tks"
aliases = [a.strip() for a in field_info["comment"].split(",")]
for alias in aliases:
field_mapping[alias] = field_name
reverse_mapping[field_name] = alias # Store first alias for reverse mapping
# Replace field names in SELECT clause
select_match = re.search(r"(select\s+.*?)(from\s+)", sql, re.IGNORECASE)
if select_match:
select_clause = select_match.group(1)
from_clause = select_match.group(2)
# Apply field transformations
for alias, actual in field_mapping.items():
select_clause = re.sub(
rf'(^|[, ]){alias}([, ]|$)',
rf'\1{actual}\2',
select_clause
)
sql = select_clause + from_clause + sql[select_match.end():]
# Also replace field names in WHERE, ORDER BY, GROUP BY, and HAVING clauses
for alias, actual in field_mapping.items():
# Transform in WHERE clause
sql = re.sub(
rf'(\bwhere\s+[^;]*?)(\b){re.escape(alias)}\b',
rf'\1{actual}',
sql,
flags=re.IGNORECASE
)
# Transform in ORDER BY clause
sql = re.sub(
rf'(\border by\s+[^;]*?)(\b){re.escape(alias)}\b',
rf'\1{actual}',
sql,
flags=re.IGNORECASE
)
# Transform in GROUP BY clause
sql = re.sub(
rf'(\bgroup by\s+[^;]*?)(\b){re.escape(alias)}\b',
rf'\1{actual}',
sql,
flags=re.IGNORECASE
)
# Transform in HAVING clause
sql = re.sub(
rf'(\bhaving\s+[^;]*?)(\b){re.escape(alias)}\b',
rf'\1{actual}',
sql,
flags=re.IGNORECASE
)
self.logger.debug(f"InfinityConnection.sql to execute: {sql}")
# Get connection parameters from the Infinity connection pool wrapper
# We need to use INFINITY_CONN singleton, not the raw ConnectionPool
from common.doc_store.infinity_conn_pool import INFINITY_CONN
conn_info = INFINITY_CONN.get_conn_uri()
# Parse host and port from conn_info
if conn_info and "host=" in conn_info:
host_match = re.search(r"host=(\S+)", conn_info)
if host_match:
host = host_match.group(1)
else:
host = "infinity"
else:
host = "infinity"
# Parse port from conn_info, default to 5432 if not found
if conn_info and "port=" in conn_info:
port_match = re.search(r"port=(\d+)", conn_info)
if port_match:
port = port_match.group(1)
else:
port = "5432"
else:
port = "5432"
# Use psql command to execute SQL
# Use full path to psql to avoid PATH issues
psql_path = "/usr/bin/psql"
# Check if psql exists at expected location, otherwise try to find it
import shutil
psql_from_path = shutil.which("psql")
if psql_from_path:
psql_path = psql_from_path
# Execute SQL with psql to get both column names and data in one call
psql_cmd = [
psql_path,
"-h", host,
"-p", port,
"-c", sql,
]
self.logger.debug(f"Executing psql command: {' '.join(psql_cmd)}")
result = subprocess.run(
psql_cmd,
capture_output=True,
text=True,
timeout=10 # 10 second timeout
)
if result.returncode != 0:
error_msg = result.stderr.strip()
raise Exception(f"psql command failed: {error_msg}\nSQL: {sql}")
# Parse the output
output = result.stdout.strip()
if not output:
# No results
return {
"columns": [],
"rows": []
} if format == "json" else []
# Parse psql table output which has format:
# col1 | col2 | col3
# -----+-----+-----
# val1 | val2 | val3
lines = output.split("\n")
# Extract column names from first line
columns = []
rows = []
if len(lines) >= 1:
header_line = lines[0]
for col_name in header_line.split("|"):
col_name = col_name.strip()
if col_name:
columns.append({"name": col_name})
# Data starts after the separator line (line with dashes)
data_start = 2 if len(lines) >= 2 and "-" in lines[1] else 1
for i in range(data_start, len(lines)):
line = lines[i].strip()
# Skip empty lines and footer lines like "(1 row)"
if not line or re.match(r"^\(\d+ row", line):
continue
# Split by | and strip each cell
row = [cell.strip() for cell in line.split("|")]
# Ensure row matches column count
if len(row) == len(columns):
rows.append(row)
elif len(row) > len(columns):
# Row has more cells than columns - truncate
rows.append(row[:len(columns)])
elif len(row) < len(columns):
# Row has fewer cells - pad with empty strings
rows.append(row + [""] * (len(columns) - len(row)))
if format == "json":
result = {
"columns": columns,
"rows": rows[:fetch_size] if fetch_size > 0 else rows
}
else:
result = rows[:fetch_size] if fetch_size > 0 else rows
return result
except subprocess.TimeoutExpired:
self.logger.exception(f"InfinityConnection.sql timeout. SQL:\n{sql}")
raise Exception(f"SQL timeout\n\nSQL: {sql}")
except Exception as e:
self.logger.exception(f"InfinityConnection.sql got exception. SQL:\n{sql}")
raise Exception(f"SQL error: {e}\n\nSQL: {sql}")

View File

@ -31,7 +31,11 @@ class InfinityConnectionPool:
if hasattr(settings, "INFINITY"):
self.INFINITY_CONFIG = settings.INFINITY
else:
self.INFINITY_CONFIG = settings.get_base_config("infinity", {"uri": "infinity:23817"})
self.INFINITY_CONFIG = settings.get_base_config("infinity", {
"uri": "infinity:23817",
"postgres_port": 5432,
"db_name": "default_db"
})
infinity_uri = self.INFINITY_CONFIG["uri"]
if ":" in infinity_uri:
@ -61,6 +65,19 @@ class InfinityConnectionPool:
def get_conn_pool(self):
return self.conn_pool
def get_conn_uri(self):
"""
Get connection URI for PostgreSQL protocol.
"""
infinity_uri = self.INFINITY_CONFIG["uri"]
postgres_port = self.INFINITY_CONFIG["postgres_port"]
db_name = self.INFINITY_CONFIG["db_name"]
if ":" in infinity_uri:
host, _ = infinity_uri.split(":")
return f"host={host} port={postgres_port} dbname={db_name}"
return f"host=localhost port={postgres_port} dbname={db_name}"
def refresh_conn_pool(self):
try:
inf_conn = self.conn_pool.get_conn()

View File

@ -249,7 +249,11 @@ def init_settings():
ES = get_base_config("es", {})
docStoreConn = rag.utils.es_conn.ESConnection()
elif lower_case_doc_engine == "infinity":
INFINITY = get_base_config("infinity", {"uri": "infinity:23817"})
INFINITY = get_base_config("infinity", {
"uri": "infinity:23817",
"postgres_port": 5432,
"db_name": "default_db"
})
docStoreConn = rag.utils.infinity_conn.InfinityConnection()
elif lower_case_doc_engine == "opensearch":
OS = get_base_config("os", {})
@ -269,7 +273,11 @@ def init_settings():
ES = get_base_config("es", {})
msgStoreConn = memory_es_conn.ESConnection()
elif DOC_ENGINE == "infinity":
INFINITY = get_base_config("infinity", {"uri": "infinity:23817"})
INFINITY = get_base_config("infinity", {
"uri": "infinity:23817",
"postgres_port": 5432,
"db_name": "default_db"
})
msgStoreConn = memory_infinity_conn.InfinityConnection()
global AZURE, S3, MINIO, OSS, GCS

View File

@ -29,6 +29,7 @@ os:
password: 'infini_rag_flow_OS_01'
infinity:
uri: 'localhost:23817'
postgres_port: 5432
db_name: 'default_db'
oceanbase:
scheme: 'oceanbase' # set 'mysql' to create connection using mysql config

View File

@ -29,6 +29,7 @@ os:
password: '${OPENSEARCH_PASSWORD:-infini_rag_flow_OS_01}'
infinity:
uri: '${INFINITY_HOST:-infinity}:23817'
postgres_port: 5432
db_name: 'default_db'
oceanbase:
scheme: 'oceanbase' # set 'mysql' to create connection using mysql config

View File

@ -33,6 +33,7 @@ from deepdoc.parser.figure_parser import vision_figure_parser_figure_xlsx_wrappe
from deepdoc.parser.utils import get_text
from rag.nlp import rag_tokenizer, tokenize, tokenize_table
from deepdoc.parser import ExcelParser
from common import settings
class Excel(ExcelParser):
@ -431,7 +432,9 @@ def chunk(filename, binary=None, from_page=0, to_page=10000000000, lang="Chinese
res = []
PY = Pinyin()
fieds_map = {"text": "_tks", "int": "_long", "keyword": "_kwd", "float": "_flt", "datetime": "_dt", "bool": "_kwd"}
# Field type suffixes for database columns
# Maps data types to their database field suffixes
fields_map = {"text": "_tks", "int": "_long", "keyword": "_kwd", "float": "_flt", "datetime": "_dt", "bool": "_kwd"}
for df in dfs:
for n in ["id", "_id", "index", "idx"]:
if n in df.columns:
@ -452,13 +455,24 @@ def chunk(filename, binary=None, from_page=0, to_page=10000000000, lang="Chinese
df[clmns[j]] = cln
if ty == "text":
txts.extend([str(c) for c in cln if c])
clmns_map = [(py_clmns[i].lower() + fieds_map[clmn_tys[i]], str(clmns[i]).replace("_", " ")) for i in
clmns_map = [(py_clmns[i].lower() + fields_map[clmn_tys[i]], str(clmns[i]).replace("_", " ")) for i in
range(len(clmns))]
# For Infinity: Use original column names as keys since they're stored in chunk_data JSON
# For ES/OS: Use full field names with type suffixes (e.g., url_kwd, body_tks)
if settings.DOC_ENGINE_INFINITY:
# For Infinity: key = original column name, value = display name
field_map = {py_clmns[i].lower(): str(clmns[i]).replace("_", " ") for i in range(len(clmns))}
else:
# For ES/OS: key = typed field name, value = display name
field_map = {k: v for k, v in clmns_map}
logging.debug(f"Field map: {field_map}")
KnowledgebaseService.update_parser_config(kwargs["kb_id"], {"field_map": field_map})
eng = lang.lower() == "english" # is_english(txts)
for ii, row in df.iterrows():
d = {"docnm_kwd": filename, "title_tks": rag_tokenizer.tokenize(re.sub(r"\.[a-zA-Z]+$", "", filename))}
row_txt = []
row_fields = []
data_json = {} # For Infinity: Store all columns in a JSON object
for j in range(len(clmns)):
if row[clmns[j]] is None:
continue
@ -466,17 +480,27 @@ def chunk(filename, binary=None, from_page=0, to_page=10000000000, lang="Chinese
continue
if not isinstance(row[clmns[j]], pd.Series) and pd.isna(row[clmns[j]]):
continue
# For Infinity: Store in chunk_data JSON column
# For Elasticsearch/OpenSearch: Store as individual fields with type suffixes
if settings.DOC_ENGINE_INFINITY:
data_json[str(clmns[j])] = row[clmns[j]]
else:
fld = clmns_map[j][0]
d[fld] = row[clmns[j]] if clmn_tys[j] != "text" else rag_tokenizer.tokenize(row[clmns[j]])
row_txt.append("{}:{}".format(clmns[j], row[clmns[j]]))
if not row_txt:
row_fields.append((clmns[j], row[clmns[j]]))
if not row_fields:
continue
tokenize(d, "; ".join(row_txt), eng)
# Add the data JSON field to the document (for Infinity only)
if settings.DOC_ENGINE_INFINITY:
d["chunk_data"] = data_json
# Format as a structured text for better LLM comprehension
# Format each field as "- Field Name: Value" on separate lines
formatted_text = "\n".join([f"- {field}: {value}" for field, value in row_fields])
tokenize(d, formatted_text, eng)
res.append(d)
if tbls:
doc = {"docnm_kwd": filename, "title_tks": rag_tokenizer.tokenize(re.sub(r"\.[a-zA-Z]+$", "", filename))}
res.extend(tokenize_table(tbls, doc, is_english))
KnowledgebaseService.update_parser_config(kwargs["kb_id"], {"field_map": {k: v for k, v in clmns_map}})
callback(0.35, "")
return res

View File

@ -558,7 +558,8 @@ def build_TOC(task, docs, progress_callback):
def init_kb(row, vector_size: int):
idxnm = search.index_name(row["tenant_id"])
return settings.docStoreConn.create_idx(idxnm, row.get("kb_id", ""), vector_size)
parser_id = row.get("parser_id", None)
return settings.docStoreConn.create_idx(idxnm, row.get("kb_id", ""), vector_size, parser_id)
async def embedding(docs, mdl, parser_config=None, callback=None):
@ -739,7 +740,7 @@ async def run_dataflow(task: dict):
start_ts = timer()
set_progress(task_id, prog=0.82, msg="[DOC Engine]:\nStart to index...")
e = await insert_es(task_id, task["tenant_id"], task["kb_id"], chunks, partial(set_progress, task_id, 0, 100000000))
e = await insert_chunks(task_id, task["tenant_id"], task["kb_id"], chunks, partial(set_progress, task_id, 0, 100000000))
if not e:
PipelineOperationLogService.create(document_id=doc_id, pipeline_id=dataflow_id,
task_type=PipelineTaskType.PARSE, dsl=str(pipeline))
@ -833,7 +834,17 @@ async def delete_image(kb_id, chunk_id):
raise
async def insert_es(task_id, task_tenant_id, task_dataset_id, chunks, progress_callback):
async def insert_chunks(task_id, task_tenant_id, task_dataset_id, chunks, progress_callback):
"""
Insert chunks into document store (Elasticsearch OR Infinity).
Args:
task_id: Task identifier
task_tenant_id: Tenant ID
task_dataset_id: Dataset/knowledge base ID
chunks: List of chunk dictionaries to insert
progress_callback: Callback function for progress updates
"""
mothers = []
mother_ids = set([])
for ck in chunks:
@ -858,7 +869,7 @@ async def insert_es(task_id, task_tenant_id, task_dataset_id, chunks, progress_c
for b in range(0, len(mothers), settings.DOC_BULK_SIZE):
await asyncio.to_thread(settings.docStoreConn.insert, mothers[b:b + settings.DOC_BULK_SIZE],
search.index_name(task_tenant_id), task_dataset_id, )
search.index_name(task_tenant_id), task_dataset_id)
task_canceled = has_canceled(task_id)
if task_canceled:
progress_callback(-1, msg="Task has been canceled.")
@ -866,7 +877,7 @@ async def insert_es(task_id, task_tenant_id, task_dataset_id, chunks, progress_c
for b in range(0, len(chunks), settings.DOC_BULK_SIZE):
doc_store_result = await asyncio.to_thread(settings.docStoreConn.insert, chunks[b:b + settings.DOC_BULK_SIZE],
search.index_name(task_tenant_id), task_dataset_id, )
search.index_name(task_tenant_id), task_dataset_id)
task_canceled = has_canceled(task_id)
if task_canceled:
progress_callback(-1, msg="Task has been canceled.")
@ -932,13 +943,6 @@ async def do_handle_task(task):
# prepare the progress callback function
progress_callback = partial(set_progress, task_id, task_from_page, task_to_page)
# FIXME: workaround, Infinity doesn't support table parsing method, this check is to notify user
lower_case_doc_engine = settings.DOC_ENGINE.lower()
if lower_case_doc_engine == 'infinity' and task['parser_id'].lower() == 'table':
error_message = "Table parsing method is not supported by Infinity, please use other parsing methods or use Elasticsearch as the document engine."
progress_callback(-1, msg=error_message)
raise Exception(error_message)
task_canceled = has_canceled(task_id)
if task_canceled:
progress_callback(-1, msg="Task has been canceled.")
@ -1092,14 +1096,14 @@ async def do_handle_task(task):
chunk_count = len(set([chunk["id"] for chunk in chunks]))
start_ts = timer()
async def _maybe_insert_es(_chunks):
async def _maybe_insert_chunks(_chunks):
if has_canceled(task_id):
return True
insert_result = await insert_es(task_id, task_tenant_id, task_dataset_id, _chunks, progress_callback)
insert_result = await insert_chunks(task_id, task_tenant_id, task_dataset_id, _chunks, progress_callback)
return bool(insert_result)
try:
if not await _maybe_insert_es(chunks):
if not await _maybe_insert_chunks(chunks):
return
logging.info(
@ -1115,7 +1119,7 @@ async def do_handle_task(task):
if toc_thread:
d = toc_thread.result()
if d:
if not await _maybe_insert_es([d]):
if not await _maybe_insert_chunks([d]):
return
DocumentService.increment_chunk_num(task_doc_id, task_dataset_id, 0, 1, 0)

View File

@ -317,7 +317,18 @@ class InfinityConnection(InfinityConnectionBase):
break
if vector_size == 0:
raise ValueError("Cannot infer vector size from documents")
self.create_idx(index_name, knowledgebase_id, vector_size)
# Determine parser_id from document structure
# Table parser documents have 'chunk_data' field
parser_id = None
if "chunk_data" in documents[0] and isinstance(documents[0].get("chunk_data"), dict):
from common.constants import ParserType
parser_id = ParserType.TABLE.value
self.logger.debug("Detected TABLE parser from document structure")
# Fallback: Create table with base schema (shouldn't normally happen as init_kb() creates it)
self.logger.debug(f"Fallback: Creating table {table_name} with base schema, parser_id: {parser_id}")
self.create_idx(index_name, knowledgebase_id, vector_size, parser_id)
table_instance = db_instance.get_table(table_name)
# embedding fields can't have a default value....
@ -378,6 +389,12 @@ class InfinityConnection(InfinityConnectionBase):
d[k] = v
elif re.search(r"_feas$", k):
d[k] = json.dumps(v)
elif k == "chunk_data":
# Convert data dict to JSON string for storage
if isinstance(v, dict):
d[k] = json.dumps(v)
else:
d[k] = v
elif k == "kb_id":
if isinstance(d[k], list):
d[k] = d[k][0] # since d[k] is a list, but we need a str
@ -586,6 +603,9 @@ class InfinityConnection(InfinityConnectionBase):
res2[column] = res2[column].apply(lambda v: [kwd for kwd in v.split("###") if kwd])
elif re.search(r"_feas$", k):
res2[column] = res2[column].apply(lambda v: json.loads(v) if v else {})
elif k == "chunk_data":
# Parse JSON data back to dict for table parser fields
res2[column] = res2[column].apply(lambda v: json.loads(v) if v and isinstance(v, str) else v)
elif k == "position_int":
def to_position_int(v):
if v:

View File

@ -49,6 +49,11 @@ def update_dataset(auth, dataset_id, payload=None, *, headers=HEADERS, data=None
def delete_datasets(auth, payload=None, *, headers=HEADERS, data=None):
"""
Delete datasets.
The endpoint is DELETE /api/{VERSION}/datasets with payload {"ids": [...]}
This is the standard SDK REST API endpoint for dataset deletion.
"""
res = requests.delete(url=f"{HOST_ADDRESS}{DATASETS_API_URL}", headers=headers, auth=auth, json=payload, data=data)
return res.json()
@ -300,12 +305,6 @@ def metadata_summary(auth, dataset_id, params=None):
# CHAT COMPLETIONS AND RELATED QUESTIONS
def chat_completions(auth, chat_assistant_id, payload=None):
url = f"{HOST_ADDRESS}{CHAT_ASSISTANT_API_URL}/{chat_assistant_id}/completions"
res = requests.post(url=url, headers=HEADERS, auth=auth, json=payload)
return res.json()
def related_questions(auth, payload=None):
url = f"{HOST_ADDRESS}/api/{VERSION}/sessions/related_questions"
res = requests.post(url=url, headers=HEADERS, auth=auth, json=payload)
@ -355,3 +354,23 @@ def agent_completions(auth, agent_id, payload=None):
url = f"{HOST_ADDRESS}{AGENT_API_URL}/{agent_id}/completions"
res = requests.post(url=url, headers=HEADERS, auth=auth, json=payload)
return res.json()
def chat_completions(auth, chat_id, payload=None):
"""
Send a question/message to a chat assistant and get completion.
Args:
auth: Authentication object
chat_id: Chat assistant ID
payload: Dictionary containing:
- question: str (required) - The question to ask
- stream: bool (optional) - Whether to stream responses, default False
- session_id: str (optional) - Session ID for conversation context
Returns:
Response JSON with answer data
"""
url = f"{HOST_ADDRESS}/api/{VERSION}/chats/{chat_id}/completions"
res = requests.post(url=url, headers=HEADERS, auth=auth, json=payload)
return res.json()

View File

@ -0,0 +1,42 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import pytest
from common import create_dataset, delete_datasets
@pytest.fixture(scope="class")
def add_table_parser_dataset(HttpApiAuth, request):
"""
Fixture to create a table parser dataset for testing.
Automatically cleans up after tests complete (deletes dataset and table).
Note: field_map is automatically generated by the table parser when processing files.
"""
dataset_payload = {
"name": "test_table_parser_dataset",
"chunk_method": "table", # table parser
}
res = create_dataset(HttpApiAuth, dataset_payload)
assert res["code"] == 0, f"Failed to create dataset: {res}"
dataset_id = res["data"]["id"]
def cleanup():
delete_datasets(HttpApiAuth, {"ids": [dataset_id]})
request.addfinalizer(cleanup)
return dataset_id

View File

@ -0,0 +1,324 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import os
import re
import tempfile
import pytest
from common import (
chat_completions,
create_chat_assistant,
create_session_with_chat_assistant,
delete_chat_assistants,
list_documents,
upload_documents,
parse_documents,
)
from utils import wait_for
@wait_for(200, 1, "Document parsing timeout")
def wait_for_parsing_completion(auth, dataset_id, document_id=None):
"""
Wait for document parsing to complete.
Args:
auth: Authentication object
dataset_id: Dataset ID
document_id: Optional specific document ID to wait for
Returns:
bool: True if parsing is complete, False otherwise
"""
res = list_documents(auth, dataset_id)
docs = res["data"]["docs"]
if document_id is None:
# Wait for all documents to complete
for doc in docs:
status = doc.get("run", "UNKNOWN")
if status != "DONE":
print(f"[DEBUG] Document {doc.get('name', 'unknown')} status: {status}, progress: {doc.get('progress', 0)}%, msg: {doc.get('progress_msg', '')}")
return False
return True
else:
# Wait for specific document
for doc in docs:
if doc["id"] == document_id:
status = doc.get("run", "UNKNOWN")
print(f"[DEBUG] Document {doc.get('name', 'unknown')} status: {status}, progress: {doc.get('progress', 0)}%, msg: {doc.get('progress_msg', '')}")
if status == "DONE":
return True
elif status == "FAILED":
pytest.fail(f"Document parsing failed: {doc}")
return False
return False
# Test data
TEST_EXCEL_DATA = [
["employee_id", "name", "department", "salary"],
["E001", "Alice Johnson", "Engineering", "95000"],
["E002", "Bob Smith", "Marketing", "65000"],
["E003", "Carol Williams", "Engineering", "88000"],
["E004", "David Brown", "Sales", "72000"],
["E005", "Eva Davis", "HR", "68000"],
["E006", "Frank Miller", "Engineering", "102000"],
]
TEST_EXCEL_DATA_2 = [
["product", "price", "category"],
["Laptop", "999", "Electronics"],
["Mouse", "29", "Electronics"],
["Desk", "299", "Furniture"],
["Chair", "199", "Furniture"],
["Monitor", "399", "Electronics"],
["Keyboard", "79", "Electronics"],
]
DEFAULT_CHAT_PROMPT = (
"You are a helpful assistant that answers questions about table data using SQL queries.\n\n"
"Here is the knowledge base:\n{knowledge}\n\n"
"Use this information to answer questions."
)
@pytest.mark.usefixtures("add_table_parser_dataset")
class TestTableParserDatasetChat:
"""
Test table parser dataset chat functionality with Infinity backend.
Verifies that:
1. Excel files are uploaded and parsed correctly into table parser datasets
2. Chat assistants can query the parsed table data via SQL
3. Different types of queries work
"""
@pytest.fixture(autouse=True)
def setup_chat_assistant(self, HttpApiAuth, add_table_parser_dataset, request):
"""
Setup fixture that runs before each test method.
Creates chat assistant once and reuses it across all test cases.
"""
# Only setup once (first time)
if not hasattr(self.__class__, 'chat_id'):
self.__class__.dataset_id = add_table_parser_dataset
self.__class__.auth = HttpApiAuth
# Upload and parse Excel files once for all tests
self._upload_and_parse_excel(HttpApiAuth, add_table_parser_dataset)
# Create a single chat assistant and session for all tests
chat_id, session_id = self._create_chat_assistant_with_session(
HttpApiAuth, add_table_parser_dataset
)
self.__class__.chat_id = chat_id
self.__class__.session_id = session_id
# Store the total number of parametrize cases
mark = request.node.get_closest_marker('parametrize')
if mark:
# Get the number of test cases from parametrize
param_values = mark.args[1]
self.__class__._total_tests = len(param_values)
else:
self.__class__._total_tests = 1
yield
# Teardown: cleanup chat assistant after all tests
# Use a class-level counter to track tests
if not hasattr(self.__class__, '_test_counter'):
self.__class__._test_counter = 0
self.__class__._test_counter += 1
# Cleanup after all parametrize tests complete
if self.__class__._test_counter >= self.__class__._total_tests:
self._teardown_chat_assistant()
def _teardown_chat_assistant(self):
"""Teardown method to clean up chat assistant."""
if hasattr(self.__class__, 'chat_id') and self.__class__.chat_id:
try:
delete_chat_assistants(self.__class__.auth, {"ids": [self.__class__.chat_id]})
except Exception as e:
print(f"[Teardown] Warning: Failed to delete chat assistant: {e}")
@pytest.mark.p1
@pytest.mark.parametrize(
"question, expected_answer_pattern",
[
("show me column of product", r"\|product\|Source"),
("which product has price 79", r"Keyboard"),
("How many rows in the dataset?", r"count\(\*\)"),
("Show me all employees in Engineering department", r"(Alice|Carol|Frank)"),
],
)
def test_table_parser_dataset_chat(self, question, expected_answer_pattern):
"""
Test that table parser dataset chat works correctly.
"""
# Use class-level attributes (set by setup fixture)
answer = self._ask_question(
self.__class__.auth,
self.__class__.chat_id,
self.__class__.session_id,
question
)
# Verify answer matches expected pattern if provided
if expected_answer_pattern:
self._assert_answer_matches_pattern(answer, expected_answer_pattern)
else:
# Just verify we got a non-empty answer
assert answer and len(answer) > 0, "Expected non-empty answer"
print(f"[Test] Question: {question}")
print(f"[Test] Answer: {answer[:100]}...")
@staticmethod
def _upload_and_parse_excel(auth, dataset_id):
"""
Upload 2 Excel files and wait for parsing to complete.
Returns:
list: The document IDs of the uploaded files
Raises:
AssertionError: If upload or parsing fails
"""
excel_file_paths = []
document_ids = []
try:
# Create 2 temporary Excel files
excel_file_paths.append(TestTableParserDatasetChat._create_temp_excel_file(TEST_EXCEL_DATA))
excel_file_paths.append(TestTableParserDatasetChat._create_temp_excel_file(TEST_EXCEL_DATA_2))
# Upload documents
res = upload_documents(auth, dataset_id, excel_file_paths)
assert res["code"] == 0, f"Failed to upload documents: {res}"
for doc in res["data"]:
document_ids.append(doc["id"])
# Start parsing for all documents
parse_payload = {"document_ids": document_ids}
res = parse_documents(auth, dataset_id, parse_payload)
assert res["code"] == 0, f"Failed to start parsing: {res}"
# Wait for parsing completion for all documents
for doc_id in document_ids:
wait_for_parsing_completion(auth, dataset_id, doc_id)
return document_ids
finally:
# Clean up temporary files
for excel_file_path in excel_file_paths:
if excel_file_path:
os.unlink(excel_file_path)
@staticmethod
def _create_temp_excel_file(data):
"""
Create a temporary Excel file with the given table test data.
Args:
data: List of lists containing the Excel data
Returns:
str: Path to the created temporary file
"""
from openpyxl import Workbook
f = tempfile.NamedTemporaryFile(mode="wb", suffix=".xlsx", delete=False)
f.close()
wb = Workbook()
ws = wb.active
# Write test data to the worksheet
for row_idx, row_data in enumerate(data, start=1):
for col_idx, value in enumerate(row_data, start=1):
ws.cell(row=row_idx, column=col_idx, value=value)
wb.save(f.name)
return f.name
@staticmethod
def _create_chat_assistant_with_session(auth, dataset_id):
"""
Create a chat assistant and session for testing.
Returns:
tuple: (chat_id, session_id)
"""
import uuid
chat_payload = {
"name": f"test_table_parser_dataset_chat_{uuid.uuid4().hex[:8]}",
"dataset_ids": [dataset_id],
"prompt_config": {
"system": DEFAULT_CHAT_PROMPT,
"parameters": [
{
"key": "knowledge",
"optional": True,
"value": "Use the table data to answer questions with SQL queries.",
}
],
},
}
res = create_chat_assistant(auth, chat_payload)
assert res["code"] == 0, f"Failed to create chat assistant: {res}"
chat_id = res["data"]["id"]
res = create_session_with_chat_assistant(auth, chat_id, {"name": f"test_session_{uuid.uuid4().hex[:8]}"})
assert res["code"] == 0, f"Failed to create session: {res}"
session_id = res["data"]["id"]
return chat_id, session_id
def _ask_question(self, auth, chat_id, session_id, question):
"""
Send a question to the chat assistant and return the answer.
Returns:
str: The assistant's answer
"""
payload = {
"question": question,
"stream": False,
"session_id": session_id,
}
res_json = chat_completions(auth, chat_id, payload)
assert res_json["code"] == 0, f"Chat completion failed: {res_json}"
return res_json["data"]["answer"]
def _assert_answer_matches_pattern(self, answer, pattern):
"""
Assert that the answer matches the expected pattern.
Args:
answer: The actual answer from the chat assistant
pattern: Regular expression pattern to match
"""
assert re.search(pattern, answer, re.IGNORECASE), (
f"Answer does not match expected pattern '{pattern}'.\n"
f"Answer: {answer}"
)