Add dataset with table parser type for Infinity and answer questions in chat using SQL (#12541)

### What problem does this PR solve?

1) Create a dataset using the table parser for Infinity
2) Answer questions in chat using SQL (see the illustrative query below)
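
For the Infinity engine, each dataset has its own table (`ragflow_{tenant_id}_{kb_id}`) and the table parser stores row fields in a `chunk_data` JSON column, so the generated SQL reads fields via `json_extract_string`. As a rough sketch (table and field names below are placeholders, not taken from this change), a question like "Show me all employees in the Engineering department" is translated into something like:

```sql
-- Illustrative only: the actual SQL is produced by the LLM from the prompts in this PR.
SELECT doc_id, docnm,
       json_extract_string(chunk_data, '$.name') AS name,
       json_extract_string(chunk_data, '$.department') AS department
FROM ragflow_tenant123_kb456
WHERE json_extract_string(chunk_data, '$.department') = 'Engineering';
```

For Elasticsearch/OpenSearch, the same flow targets the shared `ragflow_{tenant_id}` index instead and appends a `kb_id` filter to the WHERE clause.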

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
This commit is contained in:
qinling0210
2026-01-19 19:35:14 +08:00
committed by GitHub
parent 05da2a5872
commit b40d639fdb
19 changed files with 1003 additions and 101 deletions

View File

@ -53,7 +53,8 @@ RUN --mount=type=cache,id=ragflow_apt,target=/var/cache/apt,sharing=locked \
apt install -y ghostscript && \
apt install -y pandoc && \
apt install -y texlive && \
apt install -y fonts-freefont-ttf fonts-noto-cjk
apt install -y fonts-freefont-ttf fonts-noto-cjk && \
apt install -y postgresql-client
# Install uv
RUN --mount=type=bind,from=infiniflow/ragflow_deps:latest,source=/,target=/deps \

View File

@ -25,6 +25,7 @@ from api.utils.api_utils import get_data_error_result, get_json_result, get_requ
from common.misc_utils import get_uuid
from common.constants import RetCode
from api.apps import login_required, current_user
import logging
@manager.route('/set', methods=['POST']) # noqa: F821
@ -69,6 +70,19 @@ async def set_dialog():
meta_data_filter = req.get("meta_data_filter", {})
prompt_config = req["prompt_config"]
# Set default parameters for datasets with knowledge retrieval
# All datasets with {knowledge} in system prompt need "knowledge" parameter to enable retrieval
kb_ids = req.get("kb_ids", [])
parameters = prompt_config.get("parameters")
logging.debug(f"set_dialog: kb_ids={kb_ids}, parameters={parameters}, is_create={not is_create}")
# Check if parameters is missing, None, or empty list
if kb_ids and not parameters:
# Check if system prompt uses {knowledge} placeholder
if "{knowledge}" in prompt_config.get("system", ""):
# Set default parameters for any dataset with knowledge placeholder
prompt_config["parameters"] = [{"key": "knowledge", "optional": False}]
logging.debug(f"Set default parameters for datasets with knowledge placeholder: {kb_ids}")
if not is_create:
# only for chat updating
if not req.get("kb_ids", []) and not prompt_config.get("tavily_api_key") and "{knowledge}" in prompt_config.get("system", ""):

View File

@ -295,12 +295,19 @@ async def rm():
File.name == kbs[0].name,
]
)
# Delete the table BEFORE deleting the database record
for kb in kbs:
try:
settings.docStoreConn.delete({"kb_id": kb.id}, search.index_name(kb.tenant_id), kb.id)
settings.docStoreConn.delete_idx(search.index_name(kb.tenant_id), kb.id)
logging.info(f"Dropped index for dataset {kb.id}")
except Exception as e:
logging.error(f"Failed to drop index for dataset {kb.id}: {e}")
if not KnowledgebaseService.delete_by_id(req["kb_id"]):
return get_data_error_result(
message="Database error (Knowledgebase removal)!")
for kb in kbs:
settings.docStoreConn.delete({"kb_id": kb.id}, search.index_name(kb.tenant_id), kb.id)
settings.docStoreConn.delete_idx(search.index_name(kb.tenant_id), kb.id)
if hasattr(settings.STORAGE_IMPL, 'remove_bucket'):
settings.STORAGE_IMPL.remove_bucket(kb.id)
return get_json_result(data=True)

View File

@ -233,6 +233,15 @@ async def delete(tenant_id):
File2DocumentService.delete_by_document_id(doc.id)
FileService.filter_delete(
[File.source_type == FileSource.KNOWLEDGEBASE, File.type == "folder", File.name == kb.name])
# Drop index for this dataset
try:
from rag.nlp import search
idxnm = search.index_name(kb.tenant_id)
settings.docStoreConn.delete_idx(idxnm, kb_id)
except Exception as e:
logging.warning(f"Failed to drop index for dataset {kb_id}: {e}")
if not KnowledgebaseService.delete_by_id(kb_id):
errors.append(f"Delete dataset error for {kb_id}")
continue

View File

@ -37,7 +37,6 @@ from api.db.services.tenant_llm_service import TenantLLMService
from common.time_utils import current_timestamp, datetime_format
from graphrag.general.mind_map_extractor import MindMapExtractor
from rag.advanced_rag import DeepResearcher
from rag.app.resume import forbidden_select_fields4resume
from rag.app.tag import label_question
from rag.nlp.search import index_name
from rag.prompts.generator import chunks_format, citation_prompt, cross_languages, full_question, kb_prompt, keyword_extraction, message_fit_in, \
@ -274,6 +273,7 @@ def repair_bad_citation_formats(answer: str, kbinfos: dict, idx: set):
async def async_chat(dialog, messages, stream=True, **kwargs):
logging.debug("Begin async_chat")
assert messages[-1]["role"] == "user", "The last content of this conversation is not from user."
if not dialog.kb_ids and not dialog.prompt_config.get("tavily_api_key"):
async for ans in async_chat_solo(dialog, messages, stream):
@ -323,13 +323,20 @@ async def async_chat(dialog, messages, stream=True, **kwargs):
prompt_config = dialog.prompt_config
field_map = KnowledgebaseService.get_field_map(dialog.kb_ids)
logging.debug(f"field_map retrieved: {field_map}")
# try to use sql if field mapping is good to go
if field_map:
logging.debug("Use SQL to retrieval:{}".format(questions[-1]))
ans = await use_sql(questions[-1], field_map, dialog.tenant_id, chat_mdl, prompt_config.get("quote", True), dialog.kb_ids)
if ans:
# For aggregate queries (COUNT, SUM, etc.), chunks may be empty but answer is still valid
if ans and (ans.get("reference", {}).get("chunks") or ans.get("answer")):
yield ans
return
else:
logging.debug("SQL failed or returned no results, falling back to vector search")
param_keys = [p["key"] for p in prompt_config.get("parameters", [])]
logging.debug(f"attachments={attachments}, param_keys={param_keys}, embd_mdl={embd_mdl}")
for p in prompt_config["parameters"]:
if p["key"] == "knowledge":
@ -366,7 +373,8 @@ async def async_chat(dialog, messages, stream=True, **kwargs):
kbinfos = {"total": 0, "chunks": [], "doc_aggs": []}
knowledges = []
if attachments is not None and "knowledge" in [p["key"] for p in prompt_config["parameters"]]:
if attachments is not None and "knowledge" in param_keys:
logging.debug("Proceeding with retrieval")
tenant_ids = list(set([kb.tenant_id for kb in kbs]))
knowledges = []
if prompt_config.get("reasoning", False):
@ -575,112 +583,306 @@ async def async_chat(dialog, messages, stream=True, **kwargs):
async def use_sql(question, field_map, tenant_id, chat_mdl, quota=True, kb_ids=None):
sys_prompt = """
You are a Database Administrator. You need to check the fields of the following tables based on the user's list of questions and write the SQL corresponding to the last question.
Ensure that:
1. Field names should not start with a digit. If any field name starts with a digit, use double quotes around it.
2. Write only the SQL, no explanations or additional text.
"""
user_prompt = """
Table name: {};
Table of database fields are as follows:
{}
logging.debug(f"use_sql: Question: {question}")
Question are as follows:
# Determine which document engine we're using
doc_engine = "infinity" if settings.DOC_ENGINE_INFINITY else "es"
# Construct the full table name
# For Elasticsearch: ragflow_{tenant_id} (kb_id is in WHERE clause)
# For Infinity: ragflow_{tenant_id}_{kb_id} (each KB has its own table)
base_table = index_name(tenant_id)
if doc_engine == "infinity" and kb_ids and len(kb_ids) == 1:
# Infinity: append kb_id to table name
table_name = f"{base_table}_{kb_ids[0]}"
logging.debug(f"use_sql: Using Infinity table name: {table_name}")
else:
# Elasticsearch/OpenSearch: use base index name
table_name = base_table
logging.debug(f"use_sql: Using ES/OS table name: {table_name}")
# Generate engine-specific SQL prompts
if doc_engine == "infinity":
# Build Infinity prompts with JSON extraction context
json_field_names = list(field_map.keys())
sys_prompt = """You are a Database Administrator. Write SQL for a table with JSON 'chunk_data' column.
JSON Extraction: json_extract_string(chunk_data, '$.FieldName')
Numeric Cast: CAST(json_extract_string(chunk_data, '$.FieldName') AS INTEGER/FLOAT)
NULL Check: json_extract_isnull(chunk_data, '$.FieldName') == false
RULES:
1. Use EXACT field names (case-sensitive) from the list below
2. For SELECT: include doc_id, docnm, and json_extract_string() for requested fields
3. For COUNT: use COUNT(*) or COUNT(DISTINCT json_extract_string(...))
4. Add AS alias for extracted field names
5. DO NOT select 'content' field
6. Only add NULL check (json_extract_isnull() == false) in WHERE clause when:
- Question asks to "show me" or "display" specific columns
- Question mentions "not null" or "excluding null"
- Question counts a specific column, e.g. COUNT(DISTINCT field)
- DO NOT add NULL check for COUNT(*) queries (COUNT(*) counts all rows including nulls)
7. Output ONLY the SQL, no explanations"""
user_prompt = """Table: {}
Fields (EXACT case): {}
{}
Please write the SQL, only SQL, without any other explanations or text.
""".format(index_name(tenant_id), "\n".join([f"{k}: {v}" for k, v in field_map.items()]), question)
Question: {}
Write SQL using json_extract_string() with exact field names. Include doc_id, docnm for data queries. Only SQL.""".format(
table_name,
", ".join(json_field_names),
"\n".join([f" - {field}" for field in json_field_names]),
question
)
else:
# Build ES/OS prompts with direct field access
sys_prompt = """You are a Database Administrator. Write SQL queries.
RULES:
1. Use EXACT field names from the schema below (e.g., product_tks, not product)
2. Quote field names starting with digit: "123_field"
3. Add IS NOT NULL in WHERE clause when:
- Question asks to "show me" or "display" specific columns
4. Include doc_id/docnm_kwd in non-aggregate statements
5. Output ONLY the SQL, no explanations"""
user_prompt = """Table: {}
Available fields:
{}
Question: {}
Write SQL using exact field names above. Include doc_id, docnm_kwd for data queries. Only SQL.""".format(
table_name,
"\n".join([f" - {k} ({v})" for k, v in field_map.items()]),
question
)
tried_times = 0
async def get_table():
nonlocal sys_prompt, user_prompt, question, tried_times
sql = await chat_mdl.async_chat(sys_prompt, [{"role": "user", "content": user_prompt}], {"temperature": 0.06})
sql = re.sub(r"^.*</think>", "", sql, flags=re.DOTALL)
logging.debug(f"{question} ==> {user_prompt} get SQL: {sql}")
sql = re.sub(r"[\r\n]+", " ", sql.lower())
sql = re.sub(r".*select ", "select ", sql.lower())
sql = re.sub(r" +", " ", sql)
sql = re.sub(r"([;]|```).*", "", sql)
sql = re.sub(r"&", "and", sql)
if sql[: len("select ")] != "select ":
return None, None
if not re.search(r"((sum|avg|max|min)\(|group by )", sql.lower()):
if sql[: len("select *")] != "select *":
sql = "select doc_id,docnm_kwd," + sql[6:]
else:
flds = []
for k in field_map.keys():
if k in forbidden_select_fields4resume:
continue
if len(flds) > 11:
break
flds.append(k)
sql = "select doc_id,docnm_kwd," + ",".join(flds) + sql[8:]
logging.debug(f"use_sql: Raw SQL from LLM: {repr(sql[:500])}")
# Remove think blocks if present (format: </think>...)
sql = re.sub(r"</think>\n.*?\n\s*", "", sql, flags=re.DOTALL)
sql = re.sub(r"思考\n.*?\n", "", sql, flags=re.DOTALL)
# Remove markdown code blocks (```sql ... ```)
sql = re.sub(r"```(?:sql)?\s*", "", sql, flags=re.IGNORECASE)
sql = re.sub(r"```\s*$", "", sql, flags=re.IGNORECASE)
# Remove trailing semicolon that ES SQL parser doesn't like
sql = sql.rstrip().rstrip(';').strip()
if kb_ids:
kb_filter = "(" + " OR ".join([f"kb_id = '{kb_id}'" for kb_id in kb_ids]) + ")"
if "where" not in sql.lower():
# Add kb_id filter for ES/OS only (Infinity already has it in table name)
if doc_engine != "infinity" and kb_ids:
# Build kb_filter: single KB or multiple KBs with OR
if len(kb_ids) == 1:
kb_filter = f"kb_id = '{kb_ids[0]}'"
else:
kb_filter = "(" + " OR ".join([f"kb_id = '{kb_id}'" for kb_id in kb_ids]) + ")"
if "where " not in sql.lower():
o = sql.lower().split("order by")
if len(o) > 1:
sql = o[0] + f" WHERE {kb_filter} order by " + o[1]
else:
sql += f" WHERE {kb_filter}"
else:
sql += f" AND {kb_filter}"
elif "kb_id =" not in sql.lower() and "kb_id=" not in sql.lower():
sql = re.sub(r"\bwhere\b ", f"where {kb_filter} and ", sql, flags=re.IGNORECASE)
logging.debug(f"{question} get SQL(refined): {sql}")
tried_times += 1
return settings.retriever.sql_retrieval(sql, format="json"), sql
logging.debug(f"use_sql: Executing SQL retrieval (attempt {tried_times})")
tbl = settings.retriever.sql_retrieval(sql, format="json")
if tbl is None:
logging.debug("use_sql: SQL retrieval returned None")
return None, sql
logging.debug(f"use_sql: SQL retrieval completed, got {len(tbl.get('rows', []))} rows")
return tbl, sql
try:
tbl, sql = await get_table()
logging.debug(f"use_sql: Initial SQL execution SUCCESS. SQL: {sql}")
logging.debug(f"use_sql: Retrieved {len(tbl.get('rows', []))} rows, columns: {[c['name'] for c in tbl.get('columns', [])]}")
except Exception as e:
user_prompt = """
logging.warning(f"use_sql: Initial SQL execution FAILED with error: {e}")
# Build retry prompt with error information
if doc_engine == "infinity":
# Build Infinity error retry prompt
json_field_names = list(field_map.keys())
user_prompt = """
Table name: {};
JSON fields available in 'chunk_data' column (use these exact names in json_extract_string):
{}
Question: {}
Please write the SQL using json_extract_string(chunk_data, '$.field_name') with the field names from the list above. Only SQL, no explanations.
The SQL error you provided last time is as follows:
{}
Please correct the error and write SQL again using json_extract_string(chunk_data, '$.field_name') syntax with the correct field names. Only SQL, no explanations.
""".format(table_name, "\n".join([f" - {field}" for field in json_field_names]), question, e)
else:
# Build ES/OS error retry prompt
user_prompt = """
Table name: {};
Table of database fields are as follows:
Table of database fields are as follows (use the field names directly in SQL):
{}
Question are as follows:
{}
Please write the SQL, only SQL, without any other explanations or text.
Please write the SQL using the exact field names above, only SQL, without any other explanations or text.
The SQL error you provided last time is as follows:
{}
Please correct the error and write SQL again, only SQL, without any other explanations or text.
""".format(index_name(tenant_id), "\n".join([f"{k}: {v}" for k, v in field_map.items()]), question, e)
Please correct the error and write SQL again using the exact field names above, only SQL, without any other explanations or text.
""".format(table_name, "\n".join([f"{k} ({v})" for k, v in field_map.items()]), question, e)
try:
tbl, sql = await get_table()
logging.debug(f"use_sql: Retry SQL execution SUCCESS. SQL: {sql}")
logging.debug(f"use_sql: Retrieved {len(tbl.get('rows', []))} rows on retry")
except Exception:
logging.error("use_sql: Retry SQL execution also FAILED, returning None")
return
if len(tbl["rows"]) == 0:
logging.warning(f"use_sql: No rows returned from SQL query, returning None. SQL: {sql}")
return None
docid_idx = set([ii for ii, c in enumerate(tbl["columns"]) if c["name"] == "doc_id"])
doc_name_idx = set([ii for ii, c in enumerate(tbl["columns"]) if c["name"] == "docnm_kwd"])
logging.debug(f"use_sql: Proceeding with {len(tbl['rows'])} rows to build answer")
docid_idx = set([ii for ii, c in enumerate(tbl["columns"]) if c["name"].lower() == "doc_id"])
doc_name_idx = set([ii for ii, c in enumerate(tbl["columns"]) if c["name"].lower() in ["docnm_kwd", "docnm"]])
logging.debug(f"use_sql: All columns: {[(i, c['name']) for i, c in enumerate(tbl['columns'])]}")
logging.debug(f"use_sql: docid_idx={docid_idx}, doc_name_idx={doc_name_idx}")
column_idx = [ii for ii in range(len(tbl["columns"])) if ii not in (docid_idx | doc_name_idx)]
logging.debug(f"use_sql: column_idx={column_idx}")
logging.debug(f"use_sql: field_map={field_map}")
# Helper function to map column names to display names
def map_column_name(col_name):
if col_name.lower() == "count(star)":
return "COUNT(*)"
# First, try to extract AS alias from any expression (aggregate functions, json_extract_string, etc.)
# Pattern: anything AS alias_name
as_match = re.search(r'\s+AS\s+([^\s,)]+)', col_name, re.IGNORECASE)
if as_match:
alias = as_match.group(1).strip('"\'')
# Use the alias for display name lookup
if alias in field_map:
display = field_map[alias]
return re.sub(r"(/.*|[^]+)", "", display)
# If alias not in field_map, try to match case-insensitively
for field_key, display_value in field_map.items():
if field_key.lower() == alias.lower():
return re.sub(r"(/.*|[^]+)", "", display_value)
# Return alias as-is if no mapping found
return alias
# Try direct mapping first (for simple column names)
if col_name in field_map:
display = field_map[col_name]
# Clean up any suffix patterns
return re.sub(r"(/.*|[^]+)", "", display)
# Try case-insensitive match for simple column names
col_lower = col_name.lower()
for field_key, display_value in field_map.items():
if field_key.lower() == col_lower:
return re.sub(r"(/.*|[^]+)", "", display_value)
# For aggregate expressions or complex expressions without AS alias,
# try to replace field names with display names
result = col_name
for field_name, display_name in field_map.items():
# Replace field_name with display_name in the expression
result = result.replace(field_name, display_name)
# Clean up any suffix patterns
result = re.sub(r"(/.*|[^]+)", "", result)
return result
# compose Markdown table
columns = (
"|" + "|".join(
[re.sub(r"(/.*|[^]+)", "", field_map.get(tbl["columns"][i]["name"], tbl["columns"][i]["name"])) for i in column_idx]) + (
"|Source|" if docid_idx and docid_idx else "|")
[map_column_name(tbl["columns"][i]["name"]) for i in column_idx]) + (
"|Source|" if docid_idx and doc_name_idx else "|")
)
line = "|" + "|".join(["------" for _ in range(len(column_idx))]) + ("|------|" if docid_idx and docid_idx else "")
rows = ["|" + "|".join([remove_redundant_spaces(str(r[i])) for i in column_idx]).replace("None", " ") + "|" for r in tbl["rows"]]
rows = [r for r in rows if re.sub(r"[ |]+", "", r)]
# Build rows ensuring column names match values - create a dict for each row
# keyed by column name to handle any SQL column order
rows = []
for row_idx, r in enumerate(tbl["rows"]):
row_dict = {tbl["columns"][i]["name"]: r[i] for i in range(len(tbl["columns"])) if i < len(r)}
if row_idx == 0:
logging.debug(f"use_sql: First row data: {row_dict}")
row_values = []
for col_idx in column_idx:
col_name = tbl["columns"][col_idx]["name"]
value = row_dict.get(col_name, " ")
row_values.append(remove_redundant_spaces(str(value)).replace("None", " "))
# Add Source column with citation marker if Source column exists
if docid_idx and doc_name_idx:
row_values.append(f" ##{row_idx}$$")
row_str = "|" + "|".join(row_values) + "|"
if re.sub(r"[ |]+", "", row_str):
rows.append(row_str)
if quota:
rows = "\n".join([r + f" ##{ii}$$ |" for ii, r in enumerate(rows)])
rows = "\n".join(rows)
else:
rows = "\n".join([r + f" ##{ii}$$ |" for ii, r in enumerate(rows)])
rows = "\n".join(rows)
rows = re.sub(r"T[0-9]{2}:[0-9]{2}:[0-9]{2}(\.[0-9]+Z)?\|", "|", rows)
if not docid_idx or not doc_name_idx:
logging.warning("SQL missing field: " + sql)
logging.warning(f"use_sql: SQL missing required doc_id or docnm_kwd field. docid_idx={docid_idx}, doc_name_idx={doc_name_idx}. SQL: {sql}")
# For aggregate queries (COUNT, SUM, AVG, MAX, MIN, DISTINCT), fetch doc_id, docnm_kwd separately
# to provide source chunks, but keep the original table format answer
if re.search(r"(count|sum|avg|max|min|distinct)\s*\(", sql.lower()):
# Keep original table format as answer
answer = "\n".join([columns, line, rows])
# Now fetch doc_id, docnm_kwd to provide source chunks
# Extract WHERE clause from the original SQL
where_match = re.search(r"\bwhere\b(.+?)(?:\bgroup by\b|\border by\b|\blimit\b|$)", sql, re.IGNORECASE)
if where_match:
where_clause = where_match.group(1).strip()
# Build a query to get doc_id and docnm_kwd with the same WHERE clause
chunks_sql = f"select doc_id, docnm_kwd from {table_name} where {where_clause}"
# Add LIMIT to avoid fetching too many chunks
if "limit" not in chunks_sql.lower():
chunks_sql += " limit 20"
logging.debug(f"use_sql: Fetching chunks with SQL: {chunks_sql}")
try:
chunks_tbl = settings.retriever.sql_retrieval(chunks_sql, format="json")
if chunks_tbl.get("rows") and len(chunks_tbl["rows"]) > 0:
# Build chunks reference - use case-insensitive matching
chunks_did_idx = next((i for i, c in enumerate(chunks_tbl["columns"]) if c["name"].lower() == "doc_id"), None)
chunks_dn_idx = next((i for i, c in enumerate(chunks_tbl["columns"]) if c["name"].lower() in ["docnm_kwd", "docnm"]), None)
if chunks_did_idx is not None and chunks_dn_idx is not None:
chunks = [{"doc_id": r[chunks_did_idx], "docnm_kwd": r[chunks_dn_idx]} for r in chunks_tbl["rows"]]
# Build doc_aggs
doc_aggs = {}
for r in chunks_tbl["rows"]:
doc_id = r[chunks_did_idx]
doc_name = r[chunks_dn_idx]
if doc_id not in doc_aggs:
doc_aggs[doc_id] = {"doc_name": doc_name, "count": 0}
doc_aggs[doc_id]["count"] += 1
doc_aggs_list = [{"doc_id": did, "doc_name": d["doc_name"], "count": d["count"]} for did, d in doc_aggs.items()]
logging.debug(f"use_sql: Returning aggregate answer with {len(chunks)} chunks from {len(doc_aggs)} documents")
return {"answer": answer, "reference": {"chunks": chunks, "doc_aggs": doc_aggs_list}, "prompt": sys_prompt}
except Exception as e:
logging.warning(f"use_sql: Failed to fetch chunks: {e}")
# Fallback: return answer without chunks
return {"answer": answer, "reference": {"chunks": [], "doc_aggs": []}, "prompt": sys_prompt}
# Fallback to table format for other cases
return {"answer": "\n".join([columns, line, rows]), "reference": {"chunks": [], "doc_aggs": []}, "prompt": sys_prompt}
docid_idx = list(docid_idx)[0]
@ -690,7 +892,8 @@ Please write the SQL, only SQL, without any other explanations or text.
if r[docid_idx] not in doc_aggs:
doc_aggs[r[docid_idx]] = {"doc_name": r[doc_name_idx], "count": 0}
doc_aggs[r[docid_idx]]["count"] += 1
return {
result = {
"answer": "\n".join([columns, line, rows]),
"reference": {
"chunks": [{"doc_id": r[docid_idx], "docnm_kwd": r[doc_name_idx]} for r in tbl["rows"]],
@ -698,6 +901,8 @@ Please write the SQL, only SQL, without any other explanations or text.
},
"prompt": sys_prompt,
}
logging.debug(f"use_sql: Returning answer with {len(result['reference']['chunks'])} chunks from {len(doc_aggs)} documents")
return result
def clean_tts_text(text: str) -> str:
if not text:

View File

@ -1279,7 +1279,7 @@ def doc_upload_and_parse(conversation_id, file_objs, user_id):
for b in range(0, len(cks), es_bulk_size):
if try_create_idx:
if not settings.docStoreConn.index_exist(idxnm, kb_id):
settings.docStoreConn.create_idx(idxnm, kb_id, len(vectors[0]))
settings.docStoreConn.create_idx(idxnm, kb_id, len(vectors[0]), kb.parser_id)
try_create_idx = False
settings.docStoreConn.insert(cks[b:b + es_bulk_size], idxnm, kb_id)

View File

@ -164,7 +164,7 @@ class DocStoreConnection(ABC):
"""
@abstractmethod
def create_idx(self, index_name: str, dataset_id: str, vector_size: int):
def create_idx(self, index_name: str, dataset_id: str, vector_size: int, parser_id: str = None):
"""
Create an index with given name
"""

View File

@ -123,7 +123,8 @@ class ESConnectionBase(DocStoreConnection):
Table operations
"""
def create_idx(self, index_name: str, dataset_id: str, vector_size: int):
def create_idx(self, index_name: str, dataset_id: str, vector_size: int, parser_id: str = None):
# parser_id is used by Infinity but not needed for ES (kept for interface compatibility)
if self.index_exist(index_name, dataset_id):
return True
try:

View File

@ -228,15 +228,26 @@ class InfinityConnectionBase(DocStoreConnection):
Table operations
"""
def create_idx(self, index_name: str, dataset_id: str, vector_size: int):
def create_idx(self, index_name: str, dataset_id: str, vector_size: int, parser_id: str = None):
table_name = f"{index_name}_{dataset_id}"
self.logger.debug(f"CREATE_IDX: Creating table {table_name}, parser_id: {parser_id}")
inf_conn = self.connPool.get_conn()
inf_db = inf_conn.create_database(self.dbName, ConflictType.Ignore)
# Use configured schema
fp_mapping = os.path.join(get_project_base_directory(), "conf", self.mapping_file_name)
if not os.path.exists(fp_mapping):
raise Exception(f"Mapping file not found at {fp_mapping}")
schema = json.load(open(fp_mapping))
if parser_id is not None:
from common.constants import ParserType
if parser_id == ParserType.TABLE.value:
# Table parser: add chunk_data JSON column to store table-specific fields
schema["chunk_data"] = {"type": "json", "default": "{}"}
self.logger.info("Added chunk_data column for TABLE parser")
vector_name = f"q_{vector_size}_vec"
schema[vector_name] = {"type": f"vector,{vector_size},float"}
inf_table = inf_db.create_table(
@ -453,4 +464,198 @@ class InfinityConnectionBase(DocStoreConnection):
"""
def sql(self, sql: str, fetch_size: int, format: str):
raise NotImplementedError("Not implemented")
"""
Execute a SQL query on the Infinity database via the psql command.
Transform the LLM-generated SQL to match Infinity's SQL syntax before execution.
"""
import subprocess
try:
self.logger.debug(f"InfinityConnection.sql get sql: {sql}")
# Clean up SQL
sql = re.sub(r"[ `]+", " ", sql)
sql = sql.replace("%", "")
# Transform SELECT field aliases to actual stored field names
# Build field mapping from infinity_mapping.json comment field
field_mapping = {}
# Also build reverse mapping for column names in result
reverse_mapping = {}
fp_mapping = os.path.join(get_project_base_directory(), "conf", self.mapping_file_name)
if os.path.exists(fp_mapping):
schema = json.load(open(fp_mapping))
for field_name, field_info in schema.items():
if "comment" in field_info:
# Parse comma-separated aliases from comment
# e.g., "docnm_kwd, title_tks, title_sm_tks"
aliases = [a.strip() for a in field_info["comment"].split(",")]
for alias in aliases:
field_mapping[alias] = field_name
reverse_mapping[field_name] = alias # Store first alias for reverse mapping
# Replace field names in SELECT clause
select_match = re.search(r"(select\s+.*?)(from\s+)", sql, re.IGNORECASE)
if select_match:
select_clause = select_match.group(1)
from_clause = select_match.group(2)
# Apply field transformations
for alias, actual in field_mapping.items():
select_clause = re.sub(
rf'(^|[, ]){re.escape(alias)}([, ]|$)',
rf'\1{actual}\2',
select_clause
)
sql = select_clause + from_clause + sql[select_match.end():]
# Also replace field names in WHERE, ORDER BY, GROUP BY, and HAVING clauses
for alias, actual in field_mapping.items():
# Transform in WHERE clause
sql = re.sub(
rf'(\bwhere\s+[^;]*?)(\b){re.escape(alias)}\b',
rf'\1{actual}',
sql,
flags=re.IGNORECASE
)
# Transform in ORDER BY clause
sql = re.sub(
rf'(\border by\s+[^;]*?)(\b){re.escape(alias)}\b',
rf'\1{actual}',
sql,
flags=re.IGNORECASE
)
# Transform in GROUP BY clause
sql = re.sub(
rf'(\bgroup by\s+[^;]*?)(\b){re.escape(alias)}\b',
rf'\1{actual}',
sql,
flags=re.IGNORECASE
)
# Transform in HAVING clause
sql = re.sub(
rf'(\bhaving\s+[^;]*?)(\b){re.escape(alias)}\b',
rf'\1{actual}',
sql,
flags=re.IGNORECASE
)
self.logger.debug(f"InfinityConnection.sql to execute: {sql}")
# Get connection parameters from the Infinity connection pool wrapper
# We need to use INFINITY_CONN singleton, not the raw ConnectionPool
from common.doc_store.infinity_conn_pool import INFINITY_CONN
conn_info = INFINITY_CONN.get_conn_uri()
# Parse host and port from conn_info
if conn_info and "host=" in conn_info:
host_match = re.search(r"host=(\S+)", conn_info)
if host_match:
host = host_match.group(1)
else:
host = "infinity"
else:
host = "infinity"
# Parse port from conn_info, default to 5432 if not found
if conn_info and "port=" in conn_info:
port_match = re.search(r"port=(\d+)", conn_info)
if port_match:
port = port_match.group(1)
else:
port = "5432"
else:
port = "5432"
# Use psql command to execute SQL
# Use full path to psql to avoid PATH issues
psql_path = "/usr/bin/psql"
# Check if psql exists at expected location, otherwise try to find it
import shutil
psql_from_path = shutil.which("psql")
if psql_from_path:
psql_path = psql_from_path
# Execute SQL with psql to get both column names and data in one call
psql_cmd = [
psql_path,
"-h", host,
"-p", port,
"-c", sql,
]
self.logger.debug(f"Executing psql command: {' '.join(psql_cmd)}")
result = subprocess.run(
psql_cmd,
capture_output=True,
text=True,
timeout=10 # 10 second timeout
)
if result.returncode != 0:
error_msg = result.stderr.strip()
raise Exception(f"psql command failed: {error_msg}\nSQL: {sql}")
# Parse the output
output = result.stdout.strip()
if not output:
# No results
return {
"columns": [],
"rows": []
} if format == "json" else []
# Parse psql table output which has format:
# col1 | col2 | col3
# -----+-----+-----
# val1 | val2 | val3
lines = output.split("\n")
# Extract column names from first line
columns = []
rows = []
if len(lines) >= 1:
header_line = lines[0]
for col_name in header_line.split("|"):
col_name = col_name.strip()
if col_name:
columns.append({"name": col_name})
# Data starts after the separator line (line with dashes)
data_start = 2 if len(lines) >= 2 and "-" in lines[1] else 1
for i in range(data_start, len(lines)):
line = lines[i].strip()
# Skip empty lines and footer lines like "(1 row)"
if not line or re.match(r"^\(\d+ row", line):
continue
# Split by | and strip each cell
row = [cell.strip() for cell in line.split("|")]
# Ensure row matches column count
if len(row) == len(columns):
rows.append(row)
elif len(row) > len(columns):
# Row has more cells than columns - truncate
rows.append(row[:len(columns)])
elif len(row) < len(columns):
# Row has fewer cells - pad with empty strings
rows.append(row + [""] * (len(columns) - len(row)))
if format == "json":
result = {
"columns": columns,
"rows": rows[:fetch_size] if fetch_size > 0 else rows
}
else:
result = rows[:fetch_size] if fetch_size > 0 else rows
return result
except subprocess.TimeoutExpired:
self.logger.exception(f"InfinityConnection.sql timeout. SQL:\n{sql}")
raise Exception(f"SQL timeout\n\nSQL: {sql}")
except Exception as e:
self.logger.exception(f"InfinityConnection.sql got exception. SQL:\n{sql}")
raise Exception(f"SQL error: {e}\n\nSQL: {sql}")

View File

@ -31,7 +31,11 @@ class InfinityConnectionPool:
if hasattr(settings, "INFINITY"):
self.INFINITY_CONFIG = settings.INFINITY
else:
self.INFINITY_CONFIG = settings.get_base_config("infinity", {"uri": "infinity:23817"})
self.INFINITY_CONFIG = settings.get_base_config("infinity", {
"uri": "infinity:23817",
"postgres_port": 5432,
"db_name": "default_db"
})
infinity_uri = self.INFINITY_CONFIG["uri"]
if ":" in infinity_uri:
@ -61,6 +65,19 @@ class InfinityConnectionPool:
def get_conn_pool(self):
return self.conn_pool
def get_conn_uri(self):
"""
Get connection URI for PostgreSQL protocol.
"""
infinity_uri = self.INFINITY_CONFIG["uri"]
postgres_port = self.INFINITY_CONFIG["postgres_port"]
db_name = self.INFINITY_CONFIG["db_name"]
if ":" in infinity_uri:
host, _ = infinity_uri.split(":")
return f"host={host} port={postgres_port} dbname={db_name}"
return f"host=localhost port={postgres_port} dbname={db_name}"
def refresh_conn_pool(self):
try:
inf_conn = self.conn_pool.get_conn()

View File

@ -249,7 +249,11 @@ def init_settings():
ES = get_base_config("es", {})
docStoreConn = rag.utils.es_conn.ESConnection()
elif lower_case_doc_engine == "infinity":
INFINITY = get_base_config("infinity", {"uri": "infinity:23817"})
INFINITY = get_base_config("infinity", {
"uri": "infinity:23817",
"postgres_port": 5432,
"db_name": "default_db"
})
docStoreConn = rag.utils.infinity_conn.InfinityConnection()
elif lower_case_doc_engine == "opensearch":
OS = get_base_config("os", {})
@ -269,7 +273,11 @@ def init_settings():
ES = get_base_config("es", {})
msgStoreConn = memory_es_conn.ESConnection()
elif DOC_ENGINE == "infinity":
INFINITY = get_base_config("infinity", {"uri": "infinity:23817"})
INFINITY = get_base_config("infinity", {
"uri": "infinity:23817",
"postgres_port": 5432,
"db_name": "default_db"
})
msgStoreConn = memory_infinity_conn.InfinityConnection()
global AZURE, S3, MINIO, OSS, GCS

View File

@ -29,6 +29,7 @@ os:
password: 'infini_rag_flow_OS_01'
infinity:
uri: 'localhost:23817'
postgres_port: 5432
db_name: 'default_db'
oceanbase:
scheme: 'oceanbase' # set 'mysql' to create connection using mysql config

View File

@ -29,6 +29,7 @@ os:
password: '${OPENSEARCH_PASSWORD:-infini_rag_flow_OS_01}'
infinity:
uri: '${INFINITY_HOST:-infinity}:23817'
postgres_port: 5432
db_name: 'default_db'
oceanbase:
scheme: 'oceanbase' # set 'mysql' to create connection using mysql config

View File

@ -33,6 +33,7 @@ from deepdoc.parser.figure_parser import vision_figure_parser_figure_xlsx_wrappe
from deepdoc.parser.utils import get_text
from rag.nlp import rag_tokenizer, tokenize, tokenize_table
from deepdoc.parser import ExcelParser
from common import settings
class Excel(ExcelParser):
@ -431,7 +432,9 @@ def chunk(filename, binary=None, from_page=0, to_page=10000000000, lang="Chinese
res = []
PY = Pinyin()
fieds_map = {"text": "_tks", "int": "_long", "keyword": "_kwd", "float": "_flt", "datetime": "_dt", "bool": "_kwd"}
# Field type suffixes for database columns
# Maps data types to their database field suffixes
fields_map = {"text": "_tks", "int": "_long", "keyword": "_kwd", "float": "_flt", "datetime": "_dt", "bool": "_kwd"}
for df in dfs:
for n in ["id", "_id", "index", "idx"]:
if n in df.columns:
@ -452,13 +455,24 @@ def chunk(filename, binary=None, from_page=0, to_page=10000000000, lang="Chinese
df[clmns[j]] = cln
if ty == "text":
txts.extend([str(c) for c in cln if c])
clmns_map = [(py_clmns[i].lower() + fieds_map[clmn_tys[i]], str(clmns[i]).replace("_", " ")) for i in
clmns_map = [(py_clmns[i].lower() + fields_map[clmn_tys[i]], str(clmns[i]).replace("_", " ")) for i in
range(len(clmns))]
# For Infinity: Use original column names as keys since they're stored in chunk_data JSON
# For ES/OS: Use full field names with type suffixes (e.g., url_kwd, body_tks)
if settings.DOC_ENGINE_INFINITY:
# For Infinity: key = original column name, value = display name
field_map = {py_clmns[i].lower(): str(clmns[i]).replace("_", " ") for i in range(len(clmns))}
else:
# For ES/OS: key = typed field name, value = display name
field_map = {k: v for k, v in clmns_map}
logging.debug(f"Field map: {field_map}")
KnowledgebaseService.update_parser_config(kwargs["kb_id"], {"field_map": field_map})
eng = lang.lower() == "english" # is_english(txts)
for ii, row in df.iterrows():
d = {"docnm_kwd": filename, "title_tks": rag_tokenizer.tokenize(re.sub(r"\.[a-zA-Z]+$", "", filename))}
row_txt = []
row_fields = []
data_json = {} # For Infinity: Store all columns in a JSON object
for j in range(len(clmns)):
if row[clmns[j]] is None:
continue
@ -466,17 +480,27 @@ def chunk(filename, binary=None, from_page=0, to_page=10000000000, lang="Chinese
continue
if not isinstance(row[clmns[j]], pd.Series) and pd.isna(row[clmns[j]]):
continue
fld = clmns_map[j][0]
d[fld] = row[clmns[j]] if clmn_tys[j] != "text" else rag_tokenizer.tokenize(row[clmns[j]])
row_txt.append("{}:{}".format(clmns[j], row[clmns[j]]))
if not row_txt:
# For Infinity: Store in chunk_data JSON column
# For Elasticsearch/OpenSearch: Store as individual fields with type suffixes
if settings.DOC_ENGINE_INFINITY:
data_json[str(clmns[j])] = row[clmns[j]]
else:
fld = clmns_map[j][0]
d[fld] = row[clmns[j]] if clmn_tys[j] != "text" else rag_tokenizer.tokenize(row[clmns[j]])
row_fields.append((clmns[j], row[clmns[j]]))
if not row_fields:
continue
tokenize(d, "; ".join(row_txt), eng)
# Add the data JSON field to the document (for Infinity only)
if settings.DOC_ENGINE_INFINITY:
d["chunk_data"] = data_json
# Format as a structured text for better LLM comprehension
# Format each field as "- Field Name: Value" on separate lines
formatted_text = "\n".join([f"- {field}: {value}" for field, value in row_fields])
tokenize(d, formatted_text, eng)
res.append(d)
if tbls:
doc = {"docnm_kwd": filename, "title_tks": rag_tokenizer.tokenize(re.sub(r"\.[a-zA-Z]+$", "", filename))}
res.extend(tokenize_table(tbls, doc, is_english))
KnowledgebaseService.update_parser_config(kwargs["kb_id"], {"field_map": {k: v for k, v in clmns_map}})
callback(0.35, "")
return res

View File

@ -558,7 +558,8 @@ def build_TOC(task, docs, progress_callback):
def init_kb(row, vector_size: int):
idxnm = search.index_name(row["tenant_id"])
return settings.docStoreConn.create_idx(idxnm, row.get("kb_id", ""), vector_size)
parser_id = row.get("parser_id", None)
return settings.docStoreConn.create_idx(idxnm, row.get("kb_id", ""), vector_size, parser_id)
async def embedding(docs, mdl, parser_config=None, callback=None):
@ -739,7 +740,7 @@ async def run_dataflow(task: dict):
start_ts = timer()
set_progress(task_id, prog=0.82, msg="[DOC Engine]:\nStart to index...")
e = await insert_es(task_id, task["tenant_id"], task["kb_id"], chunks, partial(set_progress, task_id, 0, 100000000))
e = await insert_chunks(task_id, task["tenant_id"], task["kb_id"], chunks, partial(set_progress, task_id, 0, 100000000))
if not e:
PipelineOperationLogService.create(document_id=doc_id, pipeline_id=dataflow_id,
task_type=PipelineTaskType.PARSE, dsl=str(pipeline))
@ -833,7 +834,17 @@ async def delete_image(kb_id, chunk_id):
raise
async def insert_es(task_id, task_tenant_id, task_dataset_id, chunks, progress_callback):
async def insert_chunks(task_id, task_tenant_id, task_dataset_id, chunks, progress_callback):
"""
Insert chunks into document store (Elasticsearch OR Infinity).
Args:
task_id: Task identifier
task_tenant_id: Tenant ID
task_dataset_id: Dataset/knowledge base ID
chunks: List of chunk dictionaries to insert
progress_callback: Callback function for progress updates
"""
mothers = []
mother_ids = set([])
for ck in chunks:
@ -858,7 +869,7 @@ async def insert_es(task_id, task_tenant_id, task_dataset_id, chunks, progress_c
for b in range(0, len(mothers), settings.DOC_BULK_SIZE):
await asyncio.to_thread(settings.docStoreConn.insert, mothers[b:b + settings.DOC_BULK_SIZE],
search.index_name(task_tenant_id), task_dataset_id, )
search.index_name(task_tenant_id), task_dataset_id)
task_canceled = has_canceled(task_id)
if task_canceled:
progress_callback(-1, msg="Task has been canceled.")
@ -866,7 +877,7 @@ async def insert_es(task_id, task_tenant_id, task_dataset_id, chunks, progress_c
for b in range(0, len(chunks), settings.DOC_BULK_SIZE):
doc_store_result = await asyncio.to_thread(settings.docStoreConn.insert, chunks[b:b + settings.DOC_BULK_SIZE],
search.index_name(task_tenant_id), task_dataset_id, )
search.index_name(task_tenant_id), task_dataset_id)
task_canceled = has_canceled(task_id)
if task_canceled:
progress_callback(-1, msg="Task has been canceled.")
@ -932,13 +943,6 @@ async def do_handle_task(task):
# prepare the progress callback function
progress_callback = partial(set_progress, task_id, task_from_page, task_to_page)
# FIXME: workaround, Infinity doesn't support table parsing method, this check is to notify user
lower_case_doc_engine = settings.DOC_ENGINE.lower()
if lower_case_doc_engine == 'infinity' and task['parser_id'].lower() == 'table':
error_message = "Table parsing method is not supported by Infinity, please use other parsing methods or use Elasticsearch as the document engine."
progress_callback(-1, msg=error_message)
raise Exception(error_message)
task_canceled = has_canceled(task_id)
if task_canceled:
progress_callback(-1, msg="Task has been canceled.")
@ -1092,14 +1096,14 @@ async def do_handle_task(task):
chunk_count = len(set([chunk["id"] for chunk in chunks]))
start_ts = timer()
async def _maybe_insert_es(_chunks):
async def _maybe_insert_chunks(_chunks):
if has_canceled(task_id):
return True
insert_result = await insert_es(task_id, task_tenant_id, task_dataset_id, _chunks, progress_callback)
insert_result = await insert_chunks(task_id, task_tenant_id, task_dataset_id, _chunks, progress_callback)
return bool(insert_result)
try:
if not await _maybe_insert_es(chunks):
if not await _maybe_insert_chunks(chunks):
return
logging.info(
@ -1115,7 +1119,7 @@ async def do_handle_task(task):
if toc_thread:
d = toc_thread.result()
if d:
if not await _maybe_insert_es([d]):
if not await _maybe_insert_chunks([d]):
return
DocumentService.increment_chunk_num(task_doc_id, task_dataset_id, 0, 1, 0)

View File

@ -317,7 +317,18 @@ class InfinityConnection(InfinityConnectionBase):
break
if vector_size == 0:
raise ValueError("Cannot infer vector size from documents")
self.create_idx(index_name, knowledgebase_id, vector_size)
# Determine parser_id from document structure
# Table parser documents have 'chunk_data' field
parser_id = None
if "chunk_data" in documents[0] and isinstance(documents[0].get("chunk_data"), dict):
from common.constants import ParserType
parser_id = ParserType.TABLE.value
self.logger.debug("Detected TABLE parser from document structure")
# Fallback: Create table with base schema (shouldn't normally happen as init_kb() creates it)
self.logger.debug(f"Fallback: Creating table {table_name} with base schema, parser_id: {parser_id}")
self.create_idx(index_name, knowledgebase_id, vector_size, parser_id)
table_instance = db_instance.get_table(table_name)
# embedding fields can't have a default value....
@ -378,6 +389,12 @@ class InfinityConnection(InfinityConnectionBase):
d[k] = v
elif re.search(r"_feas$", k):
d[k] = json.dumps(v)
elif k == "chunk_data":
# Convert data dict to JSON string for storage
if isinstance(v, dict):
d[k] = json.dumps(v)
else:
d[k] = v
elif k == "kb_id":
if isinstance(d[k], list):
d[k] = d[k][0] # since d[k] is a list, but we need a str
@ -586,6 +603,9 @@ class InfinityConnection(InfinityConnectionBase):
res2[column] = res2[column].apply(lambda v: [kwd for kwd in v.split("###") if kwd])
elif re.search(r"_feas$", k):
res2[column] = res2[column].apply(lambda v: json.loads(v) if v else {})
elif k == "chunk_data":
# Parse JSON data back to dict for table parser fields
res2[column] = res2[column].apply(lambda v: json.loads(v) if v and isinstance(v, str) else v)
elif k == "position_int":
def to_position_int(v):
if v:

View File

@ -49,6 +49,11 @@ def update_dataset(auth, dataset_id, payload=None, *, headers=HEADERS, data=None
def delete_datasets(auth, payload=None, *, headers=HEADERS, data=None):
"""
Delete datasets.
The endpoint is DELETE /api/{VERSION}/datasets with payload {"ids": [...]}
This is the standard SDK REST API endpoint for dataset deletion.
"""
res = requests.delete(url=f"{HOST_ADDRESS}{DATASETS_API_URL}", headers=headers, auth=auth, json=payload, data=data)
return res.json()
@ -300,12 +305,6 @@ def metadata_summary(auth, dataset_id, params=None):
# CHAT COMPLETIONS AND RELATED QUESTIONS
def chat_completions(auth, chat_assistant_id, payload=None):
url = f"{HOST_ADDRESS}{CHAT_ASSISTANT_API_URL}/{chat_assistant_id}/completions"
res = requests.post(url=url, headers=HEADERS, auth=auth, json=payload)
return res.json()
def related_questions(auth, payload=None):
url = f"{HOST_ADDRESS}/api/{VERSION}/sessions/related_questions"
res = requests.post(url=url, headers=HEADERS, auth=auth, json=payload)
@ -355,3 +354,23 @@ def agent_completions(auth, agent_id, payload=None):
url = f"{HOST_ADDRESS}{AGENT_API_URL}/{agent_id}/completions"
res = requests.post(url=url, headers=HEADERS, auth=auth, json=payload)
return res.json()
def chat_completions(auth, chat_id, payload=None):
"""
Send a question/message to a chat assistant and get completion.
Args:
auth: Authentication object
chat_id: Chat assistant ID
payload: Dictionary containing:
- question: str (required) - The question to ask
- stream: bool (optional) - Whether to stream responses, default False
- session_id: str (optional) - Session ID for conversation context
Returns:
Response JSON with answer data
"""
url = f"{HOST_ADDRESS}/api/{VERSION}/chats/{chat_id}/completions"
res = requests.post(url=url, headers=HEADERS, auth=auth, json=payload)
return res.json()

View File

@ -0,0 +1,42 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import pytest
from common import create_dataset, delete_datasets
@pytest.fixture(scope="class")
def add_table_parser_dataset(HttpApiAuth, request):
"""
Fixture to create a table parser dataset for testing.
Automatically cleans up after tests complete (deletes dataset and table).
Note: field_map is automatically generated by the table parser when processing files.
"""
dataset_payload = {
"name": "test_table_parser_dataset",
"chunk_method": "table", # table parser
}
res = create_dataset(HttpApiAuth, dataset_payload)
assert res["code"] == 0, f"Failed to create dataset: {res}"
dataset_id = res["data"]["id"]
def cleanup():
delete_datasets(HttpApiAuth, {"ids": [dataset_id]})
request.addfinalizer(cleanup)
return dataset_id

View File

@ -0,0 +1,324 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import os
import re
import tempfile
import pytest
from common import (
chat_completions,
create_chat_assistant,
create_session_with_chat_assistant,
delete_chat_assistants,
list_documents,
upload_documents,
parse_documents,
)
from utils import wait_for
@wait_for(200, 1, "Document parsing timeout")
def wait_for_parsing_completion(auth, dataset_id, document_id=None):
"""
Wait for document parsing to complete.
Args:
auth: Authentication object
dataset_id: Dataset ID
document_id: Optional specific document ID to wait for
Returns:
bool: True if parsing is complete, False otherwise
"""
res = list_documents(auth, dataset_id)
docs = res["data"]["docs"]
if document_id is None:
# Wait for all documents to complete
for doc in docs:
status = doc.get("run", "UNKNOWN")
if status != "DONE":
print(f"[DEBUG] Document {doc.get('name', 'unknown')} status: {status}, progress: {doc.get('progress', 0)}%, msg: {doc.get('progress_msg', '')}")
return False
return True
else:
# Wait for specific document
for doc in docs:
if doc["id"] == document_id:
status = doc.get("run", "UNKNOWN")
print(f"[DEBUG] Document {doc.get('name', 'unknown')} status: {status}, progress: {doc.get('progress', 0)}%, msg: {doc.get('progress_msg', '')}")
if status == "DONE":
return True
elif status == "FAILED":
pytest.fail(f"Document parsing failed: {doc}")
return False
return False
# Test data
TEST_EXCEL_DATA = [
["employee_id", "name", "department", "salary"],
["E001", "Alice Johnson", "Engineering", "95000"],
["E002", "Bob Smith", "Marketing", "65000"],
["E003", "Carol Williams", "Engineering", "88000"],
["E004", "David Brown", "Sales", "72000"],
["E005", "Eva Davis", "HR", "68000"],
["E006", "Frank Miller", "Engineering", "102000"],
]
TEST_EXCEL_DATA_2 = [
["product", "price", "category"],
["Laptop", "999", "Electronics"],
["Mouse", "29", "Electronics"],
["Desk", "299", "Furniture"],
["Chair", "199", "Furniture"],
["Monitor", "399", "Electronics"],
["Keyboard", "79", "Electronics"],
]
DEFAULT_CHAT_PROMPT = (
"You are a helpful assistant that answers questions about table data using SQL queries.\n\n"
"Here is the knowledge base:\n{knowledge}\n\n"
"Use this information to answer questions."
)
@pytest.mark.usefixtures("add_table_parser_dataset")
class TestTableParserDatasetChat:
"""
Test table parser dataset chat functionality with Infinity backend.
Verifies that:
1. Excel files are uploaded and parsed correctly into table parser datasets
2. Chat assistants can query the parsed table data via SQL
3. Different types of queries work
"""
@pytest.fixture(autouse=True)
def setup_chat_assistant(self, HttpApiAuth, add_table_parser_dataset, request):
"""
Setup fixture that runs before each test method.
Creates chat assistant once and reuses it across all test cases.
"""
# Only setup once (first time)
if not hasattr(self.__class__, 'chat_id'):
self.__class__.dataset_id = add_table_parser_dataset
self.__class__.auth = HttpApiAuth
# Upload and parse Excel files once for all tests
self._upload_and_parse_excel(HttpApiAuth, add_table_parser_dataset)
# Create a single chat assistant and session for all tests
chat_id, session_id = self._create_chat_assistant_with_session(
HttpApiAuth, add_table_parser_dataset
)
self.__class__.chat_id = chat_id
self.__class__.session_id = session_id
# Store the total number of parametrize cases
mark = request.node.get_closest_marker('parametrize')
if mark:
# Get the number of test cases from parametrize
param_values = mark.args[1]
self.__class__._total_tests = len(param_values)
else:
self.__class__._total_tests = 1
yield
# Teardown: cleanup chat assistant after all tests
# Use a class-level counter to track tests
if not hasattr(self.__class__, '_test_counter'):
self.__class__._test_counter = 0
self.__class__._test_counter += 1
# Cleanup after all parametrize tests complete
if self.__class__._test_counter >= self.__class__._total_tests:
self._teardown_chat_assistant()
def _teardown_chat_assistant(self):
"""Teardown method to clean up chat assistant."""
if hasattr(self.__class__, 'chat_id') and self.__class__.chat_id:
try:
delete_chat_assistants(self.__class__.auth, {"ids": [self.__class__.chat_id]})
except Exception as e:
print(f"[Teardown] Warning: Failed to delete chat assistant: {e}")
@pytest.mark.p1
@pytest.mark.parametrize(
"question, expected_answer_pattern",
[
("show me column of product", r"\|product\|Source"),
("which product has price 79", r"Keyboard"),
("How many rows in the dataset?", r"count\(\*\)"),
("Show me all employees in Engineering department", r"(Alice|Carol|Frank)"),
],
)
def test_table_parser_dataset_chat(self, question, expected_answer_pattern):
"""
Test that table parser dataset chat works correctly.
"""
# Use class-level attributes (set by setup fixture)
answer = self._ask_question(
self.__class__.auth,
self.__class__.chat_id,
self.__class__.session_id,
question
)
# Verify answer matches expected pattern if provided
if expected_answer_pattern:
self._assert_answer_matches_pattern(answer, expected_answer_pattern)
else:
# Just verify we got a non-empty answer
assert answer and len(answer) > 0, "Expected non-empty answer"
print(f"[Test] Question: {question}")
print(f"[Test] Answer: {answer[:100]}...")
@staticmethod
def _upload_and_parse_excel(auth, dataset_id):
"""
Upload 2 Excel files and wait for parsing to complete.
Returns:
list: The document IDs of the uploaded files
Raises:
AssertionError: If upload or parsing fails
"""
excel_file_paths = []
document_ids = []
try:
# Create 2 temporary Excel files
excel_file_paths.append(TestTableParserDatasetChat._create_temp_excel_file(TEST_EXCEL_DATA))
excel_file_paths.append(TestTableParserDatasetChat._create_temp_excel_file(TEST_EXCEL_DATA_2))
# Upload documents
res = upload_documents(auth, dataset_id, excel_file_paths)
assert res["code"] == 0, f"Failed to upload documents: {res}"
for doc in res["data"]:
document_ids.append(doc["id"])
# Start parsing for all documents
parse_payload = {"document_ids": document_ids}
res = parse_documents(auth, dataset_id, parse_payload)
assert res["code"] == 0, f"Failed to start parsing: {res}"
# Wait for parsing completion for all documents
for doc_id in document_ids:
wait_for_parsing_completion(auth, dataset_id, doc_id)
return document_ids
finally:
# Clean up temporary files
for excel_file_path in excel_file_paths:
if excel_file_path:
os.unlink(excel_file_path)
@staticmethod
def _create_temp_excel_file(data):
"""
Create a temporary Excel file with the given table test data.
Args:
data: List of lists containing the Excel data
Returns:
str: Path to the created temporary file
"""
from openpyxl import Workbook
f = tempfile.NamedTemporaryFile(mode="wb", suffix=".xlsx", delete=False)
f.close()
wb = Workbook()
ws = wb.active
# Write test data to the worksheet
for row_idx, row_data in enumerate(data, start=1):
for col_idx, value in enumerate(row_data, start=1):
ws.cell(row=row_idx, column=col_idx, value=value)
wb.save(f.name)
return f.name
@staticmethod
def _create_chat_assistant_with_session(auth, dataset_id):
"""
Create a chat assistant and session for testing.
Returns:
tuple: (chat_id, session_id)
"""
import uuid
chat_payload = {
"name": f"test_table_parser_dataset_chat_{uuid.uuid4().hex[:8]}",
"dataset_ids": [dataset_id],
"prompt_config": {
"system": DEFAULT_CHAT_PROMPT,
"parameters": [
{
"key": "knowledge",
"optional": True,
"value": "Use the table data to answer questions with SQL queries.",
}
],
},
}
res = create_chat_assistant(auth, chat_payload)
assert res["code"] == 0, f"Failed to create chat assistant: {res}"
chat_id = res["data"]["id"]
res = create_session_with_chat_assistant(auth, chat_id, {"name": f"test_session_{uuid.uuid4().hex[:8]}"})
assert res["code"] == 0, f"Failed to create session: {res}"
session_id = res["data"]["id"]
return chat_id, session_id
def _ask_question(self, auth, chat_id, session_id, question):
"""
Send a question to the chat assistant and return the answer.
Returns:
str: The assistant's answer
"""
payload = {
"question": question,
"stream": False,
"session_id": session_id,
}
res_json = chat_completions(auth, chat_id, payload)
assert res_json["code"] == 0, f"Chat completion failed: {res_json}"
return res_json["data"]["answer"]
def _assert_answer_matches_pattern(self, answer, pattern):
"""
Assert that the answer matches the expected pattern.
Args:
answer: The actual answer from the chat assistant
pattern: Regular expression pattern to match
"""
assert re.search(pattern, answer, re.IGNORECASE), (
f"Answer does not match expected pattern '{pattern}'.\n"
f"Answer: {answer}"
)