Mirror of https://github.com/infiniflow/ragflow.git, synced 2026-01-23 03:26:53 +08:00

Add dataset with table parser type for Infinity and answer question in chat using SQL (#12541)

### What problem does this PR solve?

1) Create dataset using table parser for Infinity
2) Answer questions in chat using SQL

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
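To make the end-to-end flow of this change concrete before the diff, here is a minimal, hypothetical sketch (not code from this PR; names simplified). It assumes an LLM client with `async_chat(system, messages, params)` and a retriever with `sql_retrieval(sql, format="json")`, mirroring the interfaces used in the diff below: the model is prompted to emit SQL over the per-dataset Infinity table's `chunk_data` JSON column, the SQL is executed through the document store, and the rows are rendered as a Markdown table.

```python
# A minimal, hypothetical sketch of the flow this PR adds (illustrative only).
async def answer_with_sql(question, field_map, table_name, chat_mdl, retriever):
    # 1) Ask the LLM for SQL over the Infinity table's JSON `chunk_data` column.
    sys_prompt = "Write SQL using json_extract_string(chunk_data, '$.Field'). Output only SQL."
    user_prompt = f"Table: {table_name}\nFields: {', '.join(field_map)}\nQuestion: {question}"
    sql = await chat_mdl.async_chat(sys_prompt, [{"role": "user", "content": user_prompt}], {"temperature": 0.06})

    # 2) Execute it through the document store's SQL interface (psql against Infinity in this PR).
    tbl = retriever.sql_retrieval(sql, format="json")
    if not tbl or not tbl.get("rows"):
        return None  # caller falls back to vector retrieval

    # 3) Render the result as a Markdown table; the real code also attaches citation references.
    header = "|" + "|".join(c["name"] for c in tbl["columns"]) + "|"
    separator = "|" + "------|" * len(tbl["columns"])
    body = ["|" + "|".join(str(v) for v in row) + "|" for row in tbl["rows"]]
    return "\n".join([header, separator, *body])
```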
@@ -53,7 +53,8 @@ RUN --mount=type=cache,id=ragflow_apt,target=/var/cache/apt,sharing=locked \
apt install -y ghostscript && \
apt install -y pandoc && \
apt install -y texlive && \
apt install -y fonts-freefont-ttf fonts-noto-cjk
apt install -y fonts-freefont-ttf fonts-noto-cjk && \
apt install -y postgresql-client

# Install uv
RUN --mount=type=bind,from=infiniflow/ragflow_deps:latest,source=/,target=/deps \
@@ -25,6 +25,7 @@ from api.utils.api_utils import get_data_error_result, get_json_result, get_requ
from common.misc_utils import get_uuid
from common.constants import RetCode
from api.apps import login_required, current_user
import logging


@manager.route('/set', methods=['POST'])  # noqa: F821
@@ -69,6 +70,19 @@ async def set_dialog():
meta_data_filter = req.get("meta_data_filter", {})
prompt_config = req["prompt_config"]

# Set default parameters for datasets with knowledge retrieval
# All datasets with {knowledge} in system prompt need "knowledge" parameter to enable retrieval
kb_ids = req.get("kb_ids", [])
parameters = prompt_config.get("parameters")
logging.debug(f"set_dialog: kb_ids={kb_ids}, parameters={parameters}, is_create={not is_create}")
# Check if parameters is missing, None, or empty list
if kb_ids and not parameters:
# Check if system prompt uses {knowledge} placeholder
if "{knowledge}" in prompt_config.get("system", ""):
# Set default parameters for any dataset with knowledge placeholder
prompt_config["parameters"] = [{"key": "knowledge", "optional": False}]
logging.debug(f"Set default parameters for datasets with knowledge placeholder: {kb_ids}")

if not is_create:
# only for chat updating
if not req.get("kb_ids", []) and not prompt_config.get("tavily_api_key") and "{knowledge}" in prompt_config.get("system", ""):
@@ -295,12 +295,19 @@ async def rm():
File.name == kbs[0].name,
]
)
# Delete the table BEFORE deleting the database record
for kb in kbs:
try:
settings.docStoreConn.delete({"kb_id": kb.id}, search.index_name(kb.tenant_id), kb.id)
settings.docStoreConn.delete_idx(search.index_name(kb.tenant_id), kb.id)
logging.info(f"Dropped index for dataset {kb.id}")
except Exception as e:
logging.error(f"Failed to drop index for dataset {kb.id}: {e}")

if not KnowledgebaseService.delete_by_id(req["kb_id"]):
return get_data_error_result(
message="Database error (Knowledgebase removal)!")
for kb in kbs:
settings.docStoreConn.delete({"kb_id": kb.id}, search.index_name(kb.tenant_id), kb.id)
settings.docStoreConn.delete_idx(search.index_name(kb.tenant_id), kb.id)
if hasattr(settings.STORAGE_IMPL, 'remove_bucket'):
settings.STORAGE_IMPL.remove_bucket(kb.id)
return get_json_result(data=True)
@@ -233,6 +233,15 @@ async def delete(tenant_id):
File2DocumentService.delete_by_document_id(doc.id)
FileService.filter_delete(
[File.source_type == FileSource.KNOWLEDGEBASE, File.type == "folder", File.name == kb.name])

# Drop index for this dataset
try:
from rag.nlp import search
idxnm = search.index_name(kb.tenant_id)
settings.docStoreConn.delete_idx(idxnm, kb_id)
except Exception as e:
logging.warning(f"Failed to drop index for dataset {kb_id}: {e}")

if not KnowledgebaseService.delete_by_id(kb_id):
errors.append(f"Delete dataset error for {kb_id}")
continue
@@ -37,7 +37,6 @@ from api.db.services.tenant_llm_service import TenantLLMService
from common.time_utils import current_timestamp, datetime_format
from graphrag.general.mind_map_extractor import MindMapExtractor
from rag.advanced_rag import DeepResearcher
from rag.app.resume import forbidden_select_fields4resume
from rag.app.tag import label_question
from rag.nlp.search import index_name
from rag.prompts.generator import chunks_format, citation_prompt, cross_languages, full_question, kb_prompt, keyword_extraction, message_fit_in, \
@@ -274,6 +273,7 @@ def repair_bad_citation_formats(answer: str, kbinfos: dict, idx: set):


async def async_chat(dialog, messages, stream=True, **kwargs):
logging.debug("Begin async_chat")
assert messages[-1]["role"] == "user", "The last content of this conversation is not from user."
if not dialog.kb_ids and not dialog.prompt_config.get("tavily_api_key"):
async for ans in async_chat_solo(dialog, messages, stream):
@@ -323,13 +323,20 @@ async def async_chat(dialog, messages, stream=True, **kwargs):

prompt_config = dialog.prompt_config
field_map = KnowledgebaseService.get_field_map(dialog.kb_ids)
logging.debug(f"field_map retrieved: {field_map}")
# try to use sql if field mapping is good to go
if field_map:
logging.debug("Use SQL to retrieval:{}".format(questions[-1]))
ans = await use_sql(questions[-1], field_map, dialog.tenant_id, chat_mdl, prompt_config.get("quote", True), dialog.kb_ids)
if ans:
# For aggregate queries (COUNT, SUM, etc.), chunks may be empty but answer is still valid
if ans and (ans.get("reference", {}).get("chunks") or ans.get("answer")):
yield ans
return
else:
logging.debug("SQL failed or returned no results, falling back to vector search")

param_keys = [p["key"] for p in prompt_config.get("parameters", [])]
logging.debug(f"attachments={attachments}, param_keys={param_keys}, embd_mdl={embd_mdl}")

for p in prompt_config["parameters"]:
if p["key"] == "knowledge":
@@ -366,7 +373,8 @@ async def async_chat(dialog, messages, stream=True, **kwargs):
kbinfos = {"total": 0, "chunks": [], "doc_aggs": []}
knowledges = []

if attachments is not None and "knowledge" in [p["key"] for p in prompt_config["parameters"]]:
if attachments is not None and "knowledge" in param_keys:
logging.debug("Proceeding with retrieval")
tenant_ids = list(set([kb.tenant_id for kb in kbs]))
knowledges = []
if prompt_config.get("reasoning", False):
@@ -575,112 +583,306 @@ async def async_chat(dialog, messages, stream=True, **kwargs):


async def use_sql(question, field_map, tenant_id, chat_mdl, quota=True, kb_ids=None):
sys_prompt = """
You are a Database Administrator. You need to check the fields of the following tables based on the user's list of questions and write the SQL corresponding to the last question.
Ensure that:
1. Field names should not start with a digit. If any field name starts with a digit, use double quotes around it.
2. Write only the SQL, no explanations or additional text.
"""
user_prompt = """
Table name: {};
Table of database fields are as follows:
{}
logging.debug(f"use_sql: Question: {question}")

Question are as follows:
# Determine which document engine we're using
doc_engine = "infinity" if settings.DOC_ENGINE_INFINITY else "es"

# Construct the full table name
# For Elasticsearch: ragflow_{tenant_id} (kb_id is in WHERE clause)
# For Infinity: ragflow_{tenant_id}_{kb_id} (each KB has its own table)
base_table = index_name(tenant_id)
if doc_engine == "infinity" and kb_ids and len(kb_ids) == 1:
# Infinity: append kb_id to table name
table_name = f"{base_table}_{kb_ids[0]}"
logging.debug(f"use_sql: Using Infinity table name: {table_name}")
else:
# Elasticsearch/OpenSearch: use base index name
table_name = base_table
logging.debug(f"use_sql: Using ES/OS table name: {table_name}")

# Generate engine-specific SQL prompts
if doc_engine == "infinity":
# Build Infinity prompts with JSON extraction context
json_field_names = list(field_map.keys())
sys_prompt = """You are a Database Administrator. Write SQL for a table with JSON 'chunk_data' column.

JSON Extraction: json_extract_string(chunk_data, '$.FieldName')
Numeric Cast: CAST(json_extract_string(chunk_data, '$.FieldName') AS INTEGER/FLOAT)
NULL Check: json_extract_isnull(chunk_data, '$.FieldName') == false

RULES:
1. Use EXACT field names (case-sensitive) from the list below
2. For SELECT: include doc_id, docnm, and json_extract_string() for requested fields
3. For COUNT: use COUNT(*) or COUNT(DISTINCT json_extract_string(...))
4. Add AS alias for extracted field names
5. DO NOT select 'content' field
6. Only add NULL check (json_extract_isnull() == false) in WHERE clause when:
- Question asks to "show me" or "display" specific columns
- Question mentions "not null" or "excluding null"
- Add NULL check for count specific column
- DO NOT add NULL check for COUNT(*) queries (COUNT(*) counts all rows including nulls)
7. Output ONLY the SQL, no explanations"""
user_prompt = """Table: {}
Fields (EXACT case): {}
{}
Please write the SQL, only SQL, without any other explanations or text.
""".format(index_name(tenant_id), "\n".join([f"{k}: {v}" for k, v in field_map.items()]), question)
Question: {}
Write SQL using json_extract_string() with exact field names. Include doc_id, docnm for data queries. Only SQL.""".format(
table_name,
", ".join(json_field_names),
"\n".join([f" - {field}" for field in json_field_names]),
question
)
else:
# Build ES/OS prompts with direct field access
sys_prompt = """You are a Database Administrator. Write SQL queries.

RULES:
1. Use EXACT field names from the schema below (e.g., product_tks, not product)
2. Quote field names starting with digit: "123_field"
3. Add IS NOT NULL in WHERE clause when:
- Question asks to "show me" or "display" specific columns
4. Include doc_id/docnm in non-aggregate statement
5. Output ONLY the SQL, no explanations"""
user_prompt = """Table: {}
Available fields:
{}
Question: {}
Write SQL using exact field names above. Include doc_id, docnm_kwd for data queries. Only SQL.""".format(
table_name,
"\n".join([f" - {k} ({v})" for k, v in field_map.items()]),
question
)

tried_times = 0
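For orientation, this is the shape of statement the Infinity-specific prompt above is steering the model toward. The table name and the `department`/`salary` fields are illustrative (they happen to match the employee sample data used in the tests added later in this PR), not fixed names.

```python
# Hypothetical SQL matching the Infinity prompt rules above (illustrative table/field names).

# Data query: doc_id/docnm plus extracted fields, each with an AS alias (rules 2 and 4).
select_sql = (
    "SELECT doc_id, docnm, "
    "json_extract_string(chunk_data, '$.department') AS department, "
    "CAST(json_extract_string(chunk_data, '$.salary') AS INTEGER) AS salary "
    "FROM ragflow_tenant1_kb1 "
    "WHERE json_extract_string(chunk_data, '$.department') = 'Engineering'"
)

# Aggregate query: COUNT(*) with no NULL check, per rule 6.
count_sql = "SELECT COUNT(*) FROM ragflow_tenant1_kb1"
```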
async def get_table():
nonlocal sys_prompt, user_prompt, question, tried_times
sql = await chat_mdl.async_chat(sys_prompt, [{"role": "user", "content": user_prompt}], {"temperature": 0.06})
sql = re.sub(r"^.*</think>", "", sql, flags=re.DOTALL)
logging.debug(f"{question} ==> {user_prompt} get SQL: {sql}")
sql = re.sub(r"[\r\n]+", " ", sql.lower())
sql = re.sub(r".*select ", "select ", sql.lower())
sql = re.sub(r" +", " ", sql)
sql = re.sub(r"([;；]|```).*", "", sql)
sql = re.sub(r"&", "and", sql)
if sql[: len("select ")] != "select ":
return None, None
if not re.search(r"((sum|avg|max|min)\(|group by )", sql.lower()):
if sql[: len("select *")] != "select *":
sql = "select doc_id,docnm_kwd," + sql[6:]
else:
flds = []
for k in field_map.keys():
if k in forbidden_select_fields4resume:
continue
if len(flds) > 11:
break
flds.append(k)
sql = "select doc_id,docnm_kwd," + ",".join(flds) + sql[8:]
logging.debug(f"use_sql: Raw SQL from LLM: {repr(sql[:500])}")
# Remove think blocks if present (format: </think>...)
sql = re.sub(r"</think>\n.*?\n\s*", "", sql, flags=re.DOTALL)
sql = re.sub(r"思考\n.*?\n", "", sql, flags=re.DOTALL)
# Remove markdown code blocks (```sql ... ```)
sql = re.sub(r"```(?:sql)?\s*", "", sql, flags=re.IGNORECASE)
sql = re.sub(r"```\s*$", "", sql, flags=re.IGNORECASE)
# Remove trailing semicolon that ES SQL parser doesn't like
sql = sql.rstrip().rstrip(';').strip()

if kb_ids:
kb_filter = "(" + " OR ".join([f"kb_id = '{kb_id}'" for kb_id in kb_ids]) + ")"
if "where" not in sql.lower():
# Add kb_id filter for ES/OS only (Infinity already has it in table name)
if doc_engine != "infinity" and kb_ids:
# Build kb_filter: single KB or multiple KBs with OR
if len(kb_ids) == 1:
kb_filter = f"kb_id = '{kb_ids[0]}'"
else:
kb_filter = "(" + " OR ".join([f"kb_id = '{kb_id}'" for kb_id in kb_ids]) + ")"

if "where " not in sql.lower():
o = sql.lower().split("order by")
if len(o) > 1:
sql = o[0] + f" WHERE {kb_filter} order by " + o[1]
else:
sql += f" WHERE {kb_filter}"
else:
sql += f" AND {kb_filter}"
elif "kb_id =" not in sql.lower() and "kb_id=" not in sql.lower():
sql = re.sub(r"\bwhere\b ", f"where {kb_filter} and ", sql, flags=re.IGNORECASE)

logging.debug(f"{question} get SQL(refined): {sql}")
tried_times += 1
return settings.retriever.sql_retrieval(sql, format="json"), sql
logging.debug(f"use_sql: Executing SQL retrieval (attempt {tried_times})")
tbl = settings.retriever.sql_retrieval(sql, format="json")
if tbl is None:
logging.debug("use_sql: SQL retrieval returned None")
return None, sql
logging.debug(f"use_sql: SQL retrieval completed, got {len(tbl.get('rows', []))} rows")
return tbl, sql
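As a concrete illustration of the ES/OpenSearch branch above: a generated statement without a WHERE clause gets the kb_id filter appended, and one that already has a WHERE clause (but no kb_id condition) gets the filter spliced in after `where`. A minimal sketch with hypothetical dataset ids:

```python
import re

# Minimal sketch of the kb_id filter injection above (kb ids and field names are illustrative).
kb_ids = ["kb_a", "kb_b"]
kb_filter = "(" + " OR ".join(f"kb_id = '{k}'" for k in kb_ids) + ")"

# No WHERE clause yet: the filter is appended.
sql1 = "select doc_id,docnm_kwd,salary_long from ragflow_tenant1"
sql1 += f" WHERE {kb_filter}"
# -> select ... from ragflow_tenant1 WHERE (kb_id = 'kb_a' OR kb_id = 'kb_b')

# Existing WHERE clause without a kb_id condition: the filter is spliced in after "where".
sql2 = "select doc_id,docnm_kwd from ragflow_tenant1 where salary_long > 80000"
sql2 = re.sub(r"\bwhere\b ", f"where {kb_filter} and ", sql2, flags=re.IGNORECASE)
# -> ... where (kb_id = 'kb_a' OR kb_id = 'kb_b') and salary_long > 80000
```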
try:
tbl, sql = await get_table()
logging.debug(f"use_sql: Initial SQL execution SUCCESS. SQL: {sql}")
logging.debug(f"use_sql: Retrieved {len(tbl.get('rows', []))} rows, columns: {[c['name'] for c in tbl.get('columns', [])]}")
except Exception as e:
user_prompt = """
logging.warning(f"use_sql: Initial SQL execution FAILED with error: {e}")
# Build retry prompt with error information
if doc_engine == "infinity":
# Build Infinity error retry prompt
json_field_names = list(field_map.keys())
user_prompt = """
Table name: {};
JSON fields available in 'chunk_data' column (use these exact names in json_extract_string):
{}

Question: {}
Please write the SQL using json_extract_string(chunk_data, '$.field_name') with the field names from the list above. Only SQL, no explanations.


The SQL error you provided last time is as follows:
{}

Please correct the error and write SQL again using json_extract_string(chunk_data, '$.field_name') syntax with the correct field names. Only SQL, no explanations.
""".format(table_name, "\n".join([f" - {field}" for field in json_field_names]), question, e)
else:
# Build ES/OS error retry prompt
user_prompt = """
Table name: {};
Table of database fields are as follows:
Table of database fields are as follows (use the field names directly in SQL):
{}

Question are as follows:
{}
Please write the SQL, only SQL, without any other explanations or text.
Please write the SQL using the exact field names above, only SQL, without any other explanations or text.


The SQL error you provided last time is as follows:
{}

Please correct the error and write SQL again, only SQL, without any other explanations or text.
""".format(index_name(tenant_id), "\n".join([f"{k}: {v}" for k, v in field_map.items()]), question, e)
Please correct the error and write SQL again using the exact field names above, only SQL, without any other explanations or text.
""".format(table_name, "\n".join([f"{k} ({v})" for k, v in field_map.items()]), question, e)
try:
tbl, sql = await get_table()
logging.debug(f"use_sql: Retry SQL execution SUCCESS. SQL: {sql}")
logging.debug(f"use_sql: Retrieved {len(tbl.get('rows', []))} rows on retry")
except Exception:
logging.error("use_sql: Retry SQL execution also FAILED, returning None")
return
if len(tbl["rows"]) == 0:
logging.warning(f"use_sql: No rows returned from SQL query, returning None. SQL: {sql}")
return None

docid_idx = set([ii for ii, c in enumerate(tbl["columns"]) if c["name"] == "doc_id"])
doc_name_idx = set([ii for ii, c in enumerate(tbl["columns"]) if c["name"] == "docnm_kwd"])
logging.debug(f"use_sql: Proceeding with {len(tbl['rows'])} rows to build answer")

docid_idx = set([ii for ii, c in enumerate(tbl["columns"]) if c["name"].lower() == "doc_id"])
doc_name_idx = set([ii for ii, c in enumerate(tbl["columns"]) if c["name"].lower() in ["docnm_kwd", "docnm"]])

logging.debug(f"use_sql: All columns: {[(i, c['name']) for i, c in enumerate(tbl['columns'])]}")
logging.debug(f"use_sql: docid_idx={docid_idx}, doc_name_idx={doc_name_idx}")

column_idx = [ii for ii in range(len(tbl["columns"])) if ii not in (docid_idx | doc_name_idx)]

logging.debug(f"use_sql: column_idx={column_idx}")
logging.debug(f"use_sql: field_map={field_map}")

# Helper function to map column names to display names
def map_column_name(col_name):
if col_name.lower() == "count(star)":
return "COUNT(*)"

# First, try to extract AS alias from any expression (aggregate functions, json_extract_string, etc.)
# Pattern: anything AS alias_name
as_match = re.search(r'\s+AS\s+([^\s,)]+)', col_name, re.IGNORECASE)
if as_match:
alias = as_match.group(1).strip('"\'')

# Use the alias for display name lookup
if alias in field_map:
display = field_map[alias]
return re.sub(r"(/.*|（[^（）]+）)", "", display)
# If alias not in field_map, try to match case-insensitively
for field_key, display_value in field_map.items():
if field_key.lower() == alias.lower():
return re.sub(r"(/.*|（[^（）]+）)", "", display_value)
# Return alias as-is if no mapping found
return alias

# Try direct mapping first (for simple column names)
if col_name in field_map:
display = field_map[col_name]
# Clean up any suffix patterns
return re.sub(r"(/.*|（[^（）]+）)", "", display)

# Try case-insensitive match for simple column names
col_lower = col_name.lower()
for field_key, display_value in field_map.items():
if field_key.lower() == col_lower:
return re.sub(r"(/.*|（[^（）]+）)", "", display_value)

# For aggregate expressions or complex expressions without AS alias,
# try to replace field names with display names
result = col_name
for field_name, display_name in field_map.items():
# Replace field_name with display_name in the expression
result = result.replace(field_name, display_name)

# Clean up any suffix patterns
result = re.sub(r"(/.*|（[^（）]+）)", "", result)
return result
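To illustrate what the helper above resolves, here are a few hypothetical inputs and the display names they would map to, assuming an illustrative field_map (not data from this PR):

```python
# Illustrative behaviour of map_column_name with a hypothetical field_map.
field_map = {"salary": "Salary", "department": "Department"}

# "count(star)"                                            -> "COUNT(*)"
# "json_extract_string(chunk_data, '$.salary') AS salary"  -> "Salary"      (resolved via the AS alias)
# "DEPARTMENT"                                             -> "Department"  (case-insensitive fallback)
# "avg(salary)"                                            -> "avg(Salary)" (field names replaced inside the expression)
```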
# compose Markdown table
columns = (
"|" + "|".join(
[re.sub(r"(/.*|（[^（）]+）)", "", field_map.get(tbl["columns"][i]["name"], tbl["columns"][i]["name"])) for i in column_idx]) + (
"|Source|" if docid_idx and docid_idx else "|")
[map_column_name(tbl["columns"][i]["name"]) for i in column_idx]) + (
"|Source|" if docid_idx and doc_name_idx else "|")
)

line = "|" + "|".join(["------" for _ in range(len(column_idx))]) + ("|------|" if docid_idx and docid_idx else "")

rows = ["|" + "|".join([remove_redundant_spaces(str(r[i])) for i in column_idx]).replace("None", " ") + "|" for r in tbl["rows"]]
rows = [r for r in rows if re.sub(r"[ |]+", "", r)]
# Build rows ensuring column names match values - create a dict for each row
# keyed by column name to handle any SQL column order
rows = []
for row_idx, r in enumerate(tbl["rows"]):
row_dict = {tbl["columns"][i]["name"]: r[i] for i in range(len(tbl["columns"])) if i < len(r)}
if row_idx == 0:
logging.debug(f"use_sql: First row data: {row_dict}")
row_values = []
for col_idx in column_idx:
col_name = tbl["columns"][col_idx]["name"]
value = row_dict.get(col_name, " ")
row_values.append(remove_redundant_spaces(str(value)).replace("None", " "))
# Add Source column with citation marker if Source column exists
if docid_idx and doc_name_idx:
row_values.append(f" ##{row_idx}$$")
row_str = "|" + "|".join(row_values) + "|"
if re.sub(r"[ |]+", "", row_str):
rows.append(row_str)
if quota:
rows = "\n".join([r + f" ##{ii}$$ |" for ii, r in enumerate(rows)])
rows = "\n".join(rows)
else:
rows = "\n".join([r + f" ##{ii}$$ |" for ii, r in enumerate(rows)])
rows = "\n".join(rows)
rows = re.sub(r"T[0-9]{2}:[0-9]{2}:[0-9]{2}(\.[0-9]+Z)?\|", "|", rows)

if not docid_idx or not doc_name_idx:
logging.warning("SQL missing field: " + sql)
logging.warning(f"use_sql: SQL missing required doc_id or docnm_kwd field. docid_idx={docid_idx}, doc_name_idx={doc_name_idx}. SQL: {sql}")
# For aggregate queries (COUNT, SUM, AVG, MAX, MIN, DISTINCT), fetch doc_id, docnm_kwd separately
# to provide source chunks, but keep the original table format answer
if re.search(r"(count|sum|avg|max|min|distinct)\s*\(", sql.lower()):
# Keep original table format as answer
answer = "\n".join([columns, line, rows])

# Now fetch doc_id, docnm_kwd to provide source chunks
# Extract WHERE clause from the original SQL
where_match = re.search(r"\bwhere\b(.+?)(?:\bgroup by\b|\border by\b|\blimit\b|$)", sql, re.IGNORECASE)
if where_match:
where_clause = where_match.group(1).strip()
# Build a query to get doc_id and docnm_kwd with the same WHERE clause
chunks_sql = f"select doc_id, docnm_kwd from {table_name} where {where_clause}"
# Add LIMIT to avoid fetching too many chunks
if "limit" not in chunks_sql.lower():
chunks_sql += " limit 20"
logging.debug(f"use_sql: Fetching chunks with SQL: {chunks_sql}")
try:
chunks_tbl = settings.retriever.sql_retrieval(chunks_sql, format="json")
if chunks_tbl.get("rows") and len(chunks_tbl["rows"]) > 0:
# Build chunks reference - use case-insensitive matching
chunks_did_idx = next((i for i, c in enumerate(chunks_tbl["columns"]) if c["name"].lower() == "doc_id"), None)
chunks_dn_idx = next((i for i, c in enumerate(chunks_tbl["columns"]) if c["name"].lower() in ["docnm_kwd", "docnm"]), None)
if chunks_did_idx is not None and chunks_dn_idx is not None:
chunks = [{"doc_id": r[chunks_did_idx], "docnm_kwd": r[chunks_dn_idx]} for r in chunks_tbl["rows"]]
# Build doc_aggs
doc_aggs = {}
for r in chunks_tbl["rows"]:
doc_id = r[chunks_did_idx]
doc_name = r[chunks_dn_idx]
if doc_id not in doc_aggs:
doc_aggs[doc_id] = {"doc_name": doc_name, "count": 0}
doc_aggs[doc_id]["count"] += 1
doc_aggs_list = [{"doc_id": did, "doc_name": d["doc_name"], "count": d["count"]} for did, d in doc_aggs.items()]
logging.debug(f"use_sql: Returning aggregate answer with {len(chunks)} chunks from {len(doc_aggs)} documents")
return {"answer": answer, "reference": {"chunks": chunks, "doc_aggs": doc_aggs_list}, "prompt": sys_prompt}
except Exception as e:
logging.warning(f"use_sql: Failed to fetch chunks: {e}")
# Fallback: return answer without chunks
return {"answer": answer, "reference": {"chunks": [], "doc_aggs": []}, "prompt": sys_prompt}
# Fallback to table format for other cases
return {"answer": "\n".join([columns, line, rows]), "reference": {"chunks": [], "doc_aggs": []}, "prompt": sys_prompt}

docid_idx = list(docid_idx)[0]
@@ -690,7 +892,8 @@ Please write the SQL, only SQL, without any other explanations or text.
if r[docid_idx] not in doc_aggs:
doc_aggs[r[docid_idx]] = {"doc_name": r[doc_name_idx], "count": 0}
doc_aggs[r[docid_idx]]["count"] += 1
return {

result = {
"answer": "\n".join([columns, line, rows]),
"reference": {
"chunks": [{"doc_id": r[docid_idx], "docnm_kwd": r[doc_name_idx]} for r in tbl["rows"]],
@@ -698,6 +901,8 @@ Please write the SQL, only SQL, without any other explanations or text.
},
"prompt": sys_prompt,
}
logging.debug(f"use_sql: Returning answer with {len(result['reference']['chunks'])} chunks from {len(doc_aggs)} documents")
return result

def clean_tts_text(text: str) -> str:
if not text:
@@ -1279,7 +1279,7 @@ def doc_upload_and_parse(conversation_id, file_objs, user_id):
for b in range(0, len(cks), es_bulk_size):
if try_create_idx:
if not settings.docStoreConn.index_exist(idxnm, kb_id):
settings.docStoreConn.create_idx(idxnm, kb_id, len(vectors[0]))
settings.docStoreConn.create_idx(idxnm, kb_id, len(vectors[0]), kb.parser_id)
try_create_idx = False
settings.docStoreConn.insert(cks[b:b + es_bulk_size], idxnm, kb_id)
@@ -164,7 +164,7 @@ class DocStoreConnection(ABC):
"""

@abstractmethod
def create_idx(self, index_name: str, dataset_id: str, vector_size: int):
def create_idx(self, index_name: str, dataset_id: str, vector_size: int, parser_id: str = None):
"""
Create an index with given name
"""
@@ -123,7 +123,8 @@ class ESConnectionBase(DocStoreConnection):
Table operations
"""

def create_idx(self, index_name: str, dataset_id: str, vector_size: int):
def create_idx(self, index_name: str, dataset_id: str, vector_size: int, parser_id: str = None):
# parser_id is used by Infinity but not needed for ES (kept for interface compatibility)
if self.index_exist(index_name, dataset_id):
return True
try:
@@ -228,15 +228,26 @@ class InfinityConnectionBase(DocStoreConnection):
Table operations
"""

def create_idx(self, index_name: str, dataset_id: str, vector_size: int):
def create_idx(self, index_name: str, dataset_id: str, vector_size: int, parser_id: str = None):
table_name = f"{index_name}_{dataset_id}"
self.logger.debug(f"CREATE_IDX: Creating table {table_name}, parser_id: {parser_id}")

inf_conn = self.connPool.get_conn()
inf_db = inf_conn.create_database(self.dbName, ConflictType.Ignore)

# Use configured schema
fp_mapping = os.path.join(get_project_base_directory(), "conf", self.mapping_file_name)
if not os.path.exists(fp_mapping):
raise Exception(f"Mapping file not found at {fp_mapping}")
schema = json.load(open(fp_mapping))

if parser_id is not None:
from common.constants import ParserType
if parser_id == ParserType.TABLE.value:
# Table parser: add chunk_data JSON column to store table-specific fields
schema["chunk_data"] = {"type": "json", "default": "{}"}
self.logger.info("Added chunk_data column for TABLE parser")

vector_name = f"q_{vector_size}_vec"
schema[vector_name] = {"type": f"vector,{vector_size},float"}
inf_table = inf_db.create_table(
@@ -453,4 +464,198 @@ class InfinityConnectionBase(DocStoreConnection):
"""

def sql(self, sql: str, fetch_size: int, format: str):
raise NotImplementedError("Not implemented")
"""
Execute SQL query on Infinity database via psql command.
Transform text-to-sql for Infinity's SQL syntax.
"""
import subprocess

try:
self.logger.debug(f"InfinityConnection.sql get sql: {sql}")

# Clean up SQL
sql = re.sub(r"[ `]+", " ", sql)
sql = sql.replace("%", "")

# Transform SELECT field aliases to actual stored field names
# Build field mapping from infinity_mapping.json comment field
field_mapping = {}
# Also build reverse mapping for column names in result
reverse_mapping = {}
fp_mapping = os.path.join(get_project_base_directory(), "conf", self.mapping_file_name)
if os.path.exists(fp_mapping):
schema = json.load(open(fp_mapping))
for field_name, field_info in schema.items():
if "comment" in field_info:
# Parse comma-separated aliases from comment
# e.g., "docnm_kwd, title_tks, title_sm_tks"
aliases = [a.strip() for a in field_info["comment"].split(",")]
for alias in aliases:
field_mapping[alias] = field_name
reverse_mapping[field_name] = alias  # Store first alias for reverse mapping
# Replace field names in SELECT clause
select_match = re.search(r"(select\s+.*?)(from\s+)", sql, re.IGNORECASE)
if select_match:
select_clause = select_match.group(1)
from_clause = select_match.group(2)

# Apply field transformations
for alias, actual in field_mapping.items():
select_clause = re.sub(
rf'(^|[, ]){alias}([, ]|$)',
rf'\1{actual}\2',
select_clause
)

sql = select_clause + from_clause + sql[select_match.end():]

# Also replace field names in WHERE, ORDER BY, GROUP BY, and HAVING clauses
for alias, actual in field_mapping.items():
# Transform in WHERE clause
sql = re.sub(
rf'(\bwhere\s+[^;]*?)(\b){re.escape(alias)}\b',
rf'\1{actual}',
sql,
flags=re.IGNORECASE
)
# Transform in ORDER BY clause
sql = re.sub(
rf'(\border by\s+[^;]*?)(\b){re.escape(alias)}\b',
rf'\1{actual}',
sql,
flags=re.IGNORECASE
)
# Transform in GROUP BY clause
sql = re.sub(
rf'(\bgroup by\s+[^;]*?)(\b){re.escape(alias)}\b',
rf'\1{actual}',
sql,
flags=re.IGNORECASE
)
# Transform in HAVING clause
sql = re.sub(
rf'(\bhaving\s+[^;]*?)(\b){re.escape(alias)}\b',
rf'\1{actual}',
sql,
flags=re.IGNORECASE
)

self.logger.debug(f"InfinityConnection.sql to execute: {sql}")
# Get connection parameters from the Infinity connection pool wrapper
# We need to use INFINITY_CONN singleton, not the raw ConnectionPool
from common.doc_store.infinity_conn_pool import INFINITY_CONN
conn_info = INFINITY_CONN.get_conn_uri()

# Parse host and port from conn_info
if conn_info and "host=" in conn_info:
host_match = re.search(r"host=(\S+)", conn_info)
if host_match:
host = host_match.group(1)
else:
host = "infinity"
else:
host = "infinity"

# Parse port from conn_info, default to 5432 if not found
if conn_info and "port=" in conn_info:
port_match = re.search(r"port=(\d+)", conn_info)
if port_match:
port = port_match.group(1)
else:
port = "5432"
else:
port = "5432"
# Use psql command to execute SQL
# Use full path to psql to avoid PATH issues
psql_path = "/usr/bin/psql"
# Check if psql exists at expected location, otherwise try to find it
import shutil
psql_from_path = shutil.which("psql")
if psql_from_path:
psql_path = psql_from_path

# Execute SQL with psql to get both column names and data in one call
psql_cmd = [
psql_path,
"-h", host,
"-p", port,
"-c", sql,
]

self.logger.debug(f"Executing psql command: {' '.join(psql_cmd)}")

result = subprocess.run(
psql_cmd,
capture_output=True,
text=True,
timeout=10  # 10 second timeout
)

if result.returncode != 0:
error_msg = result.stderr.strip()
raise Exception(f"psql command failed: {error_msg}\nSQL: {sql}")

# Parse the output
output = result.stdout.strip()
if not output:
# No results
return {
"columns": [],
"rows": []
} if format == "json" else []
# Parse psql table output which has format:
# col1 | col2 | col3
# -----+-----+-----
# val1 | val2 | val3
lines = output.split("\n")

# Extract column names from first line
columns = []
rows = []

if len(lines) >= 1:
header_line = lines[0]
for col_name in header_line.split("|"):
col_name = col_name.strip()
if col_name:
columns.append({"name": col_name})

# Data starts after the separator line (line with dashes)
data_start = 2 if len(lines) >= 2 and "-" in lines[1] else 1
for i in range(data_start, len(lines)):
line = lines[i].strip()
# Skip empty lines and footer lines like "(1 row)"
if not line or re.match(r"^\(\d+ row", line):
continue
# Split by | and strip each cell
row = [cell.strip() for cell in line.split("|")]
# Ensure row matches column count
if len(row) == len(columns):
rows.append(row)
elif len(row) > len(columns):
# Row has more cells than columns - truncate
rows.append(row[:len(columns)])
elif len(row) < len(columns):
# Row has fewer cells - pad with empty strings
rows.append(row + [""] * (len(columns) - len(row)))

if format == "json":
result = {
"columns": columns,
"rows": rows[:fetch_size] if fetch_size > 0 else rows
}
else:
result = rows[:fetch_size] if fetch_size > 0 else rows

return result

except subprocess.TimeoutExpired:
self.logger.exception(f"InfinityConnection.sql timeout. SQL:\n{sql}")
raise Exception(f"SQL timeout\n\nSQL: {sql}")
except Exception as e:
self.logger.exception(f"InfinityConnection.sql got exception. SQL:\n{sql}")
raise Exception(f"SQL error: {e}\n\nSQL: {sql}")
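To make the parsing described above concrete, here is a small, self-contained sketch that turns psql's aligned text output into the {"columns": ..., "rows": ...} shape the retriever expects; the sample output is illustrative.

```python
import re

# Illustrative psql output for "SELECT doc_id, docnm FROM ...": header, separator, data, footer.
sample_output = """ doc_id | docnm
--------+-----------
 d1     | sales.xlsx
 d2     | hr.xlsx
(2 rows)"""

lines = sample_output.strip().split("\n")
columns = [{"name": c.strip()} for c in lines[0].split("|") if c.strip()]
rows = []
for line in lines[2:]:  # skip the dashed separator line
    line = line.strip()
    if not line or re.match(r"^\(\d+ row", line):  # skip the "(2 rows)" footer
        continue
    rows.append([cell.strip() for cell in line.split("|")])

print({"columns": columns, "rows": rows})
# {'columns': [{'name': 'doc_id'}, {'name': 'docnm'}], 'rows': [['d1', 'sales.xlsx'], ['d2', 'hr.xlsx']]}
```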
@@ -31,7 +31,11 @@ class InfinityConnectionPool:
if hasattr(settings, "INFINITY"):
self.INFINITY_CONFIG = settings.INFINITY
else:
self.INFINITY_CONFIG = settings.get_base_config("infinity", {"uri": "infinity:23817"})
self.INFINITY_CONFIG = settings.get_base_config("infinity", {
"uri": "infinity:23817",
"postgres_port": 5432,
"db_name": "default_db"
})

infinity_uri = self.INFINITY_CONFIG["uri"]
if ":" in infinity_uri:
@@ -61,6 +65,19 @@ class InfinityConnectionPool:
def get_conn_pool(self):
return self.conn_pool

def get_conn_uri(self):
"""
Get connection URI for PostgreSQL protocol.
"""
infinity_uri = self.INFINITY_CONFIG["uri"]
postgres_port = self.INFINITY_CONFIG["postgres_port"]
db_name = self.INFINITY_CONFIG["db_name"]

if ":" in infinity_uri:
host, _ = infinity_uri.split(":")
return f"host={host} port={postgres_port} dbname={db_name}"
return f"host=localhost port={postgres_port} dbname={db_name}"

def refresh_conn_pool(self):
try:
inf_conn = self.conn_pool.get_conn()
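With the defaults shown above (uri 'infinity:23817', postgres_port 5432, db_name 'default_db'), the PostgreSQL-protocol connection string produced by get_conn_uri() looks like the following; the sql() method earlier parses host and port back out of this string before invoking psql.

```python
# Illustrative value returned by get_conn_uri() under the default configuration above.
conn_uri = "host=infinity port=5432 dbname=default_db"
```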
@@ -249,7 +249,11 @@ def init_settings():
ES = get_base_config("es", {})
docStoreConn = rag.utils.es_conn.ESConnection()
elif lower_case_doc_engine == "infinity":
INFINITY = get_base_config("infinity", {"uri": "infinity:23817"})
INFINITY = get_base_config("infinity", {
"uri": "infinity:23817",
"postgres_port": 5432,
"db_name": "default_db"
})
docStoreConn = rag.utils.infinity_conn.InfinityConnection()
elif lower_case_doc_engine == "opensearch":
OS = get_base_config("os", {})
@@ -269,7 +273,11 @@ def init_settings():
ES = get_base_config("es", {})
msgStoreConn = memory_es_conn.ESConnection()
elif DOC_ENGINE == "infinity":
INFINITY = get_base_config("infinity", {"uri": "infinity:23817"})
INFINITY = get_base_config("infinity", {
"uri": "infinity:23817",
"postgres_port": 5432,
"db_name": "default_db"
})
msgStoreConn = memory_infinity_conn.InfinityConnection()

global AZURE, S3, MINIO, OSS, GCS
@@ -29,6 +29,7 @@ os:
password: 'infini_rag_flow_OS_01'
infinity:
uri: 'localhost:23817'
postgres_port: 5432
db_name: 'default_db'
oceanbase:
scheme: 'oceanbase'  # set 'mysql' to create connection using mysql config

@@ -29,6 +29,7 @@ os:
password: '${OPENSEARCH_PASSWORD:-infini_rag_flow_OS_01}'
infinity:
uri: '${INFINITY_HOST:-infinity}:23817'
postgres_port: 5432
db_name: 'default_db'
oceanbase:
scheme: 'oceanbase'  # set 'mysql' to create connection using mysql config
@@ -33,6 +33,7 @@ from deepdoc.parser.figure_parser import vision_figure_parser_figure_xlsx_wrappe
from deepdoc.parser.utils import get_text
from rag.nlp import rag_tokenizer, tokenize, tokenize_table
from deepdoc.parser import ExcelParser
from common import settings


class Excel(ExcelParser):
@@ -431,7 +432,9 @@ def chunk(filename, binary=None, from_page=0, to_page=10000000000, lang="Chinese

res = []
PY = Pinyin()
fieds_map = {"text": "_tks", "int": "_long", "keyword": "_kwd", "float": "_flt", "datetime": "_dt", "bool": "_kwd"}
# Field type suffixes for database columns
# Maps data types to their database field suffixes
fields_map = {"text": "_tks", "int": "_long", "keyword": "_kwd", "float": "_flt", "datetime": "_dt", "bool": "_kwd"}
for df in dfs:
for n in ["id", "_id", "index", "idx"]:
if n in df.columns:
@@ -452,13 +455,24 @@ def chunk(filename, binary=None, from_page=0, to_page=10000000000, lang="Chinese
df[clmns[j]] = cln
if ty == "text":
txts.extend([str(c) for c in cln if c])
clmns_map = [(py_clmns[i].lower() + fieds_map[clmn_tys[i]], str(clmns[i]).replace("_", " ")) for i in
clmns_map = [(py_clmns[i].lower() + fields_map[clmn_tys[i]], str(clmns[i]).replace("_", " ")) for i in
range(len(clmns))]
# For Infinity: Use original column names as keys since they're stored in chunk_data JSON
# For ES/OS: Use full field names with type suffixes (e.g., url_kwd, body_tks)
if settings.DOC_ENGINE_INFINITY:
# For Infinity: key = original column name, value = display name
field_map = {py_clmns[i].lower(): str(clmns[i]).replace("_", " ") for i in range(len(clmns))}
else:
# For ES/OS: key = typed field name, value = display name
field_map = {k: v for k, v in clmns_map}
logging.debug(f"Field map: {field_map}")
KnowledgebaseService.update_parser_config(kwargs["kb_id"], {"field_map": field_map})

eng = lang.lower() == "english"  # is_english(txts)
for ii, row in df.iterrows():
d = {"docnm_kwd": filename, "title_tks": rag_tokenizer.tokenize(re.sub(r"\.[a-zA-Z]+$", "", filename))}
row_txt = []
row_fields = []
data_json = {}  # For Infinity: Store all columns in a JSON object
for j in range(len(clmns)):
if row[clmns[j]] is None:
continue
@@ -466,17 +480,27 @@ def chunk(filename, binary=None, from_page=0, to_page=10000000000, lang="Chinese
continue
if not isinstance(row[clmns[j]], pd.Series) and pd.isna(row[clmns[j]]):
continue
fld = clmns_map[j][0]
d[fld] = row[clmns[j]] if clmn_tys[j] != "text" else rag_tokenizer.tokenize(row[clmns[j]])
row_txt.append("{}:{}".format(clmns[j], row[clmns[j]]))
if not row_txt:
# For Infinity: Store in chunk_data JSON column
# For Elasticsearch/OpenSearch: Store as individual fields with type suffixes
if settings.DOC_ENGINE_INFINITY:
data_json[str(clmns[j])] = row[clmns[j]]
else:
fld = clmns_map[j][0]
d[fld] = row[clmns[j]] if clmn_tys[j] != "text" else rag_tokenizer.tokenize(row[clmns[j]])
row_fields.append((clmns[j], row[clmns[j]]))
if not row_fields:
continue
tokenize(d, "; ".join(row_txt), eng)
# Add the data JSON field to the document (for Infinity only)
if settings.DOC_ENGINE_INFINITY:
d["chunk_data"] = data_json
# Format as a structured text for better LLM comprehension
# Format each field as "- Field Name: Value" on separate lines
formatted_text = "\n".join([f"- {field}: {value}" for field, value in row_fields])
tokenize(d, formatted_text, eng)
res.append(d)
if tbls:
doc = {"docnm_kwd": filename, "title_tks": rag_tokenizer.tokenize(re.sub(r"\.[a-zA-Z]+$", "", filename))}
res.extend(tokenize_table(tbls, doc, is_english))
KnowledgebaseService.update_parser_config(kwargs["kb_id"], {"field_map": {k: v for k, v in clmns_map}})
callback(0.35, "")

return res
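To show the difference the branch above introduces, here is a hedged sketch of the document built for one row of the employee sample table from the tests: for Infinity the original column names are kept inside a single chunk_data JSON object, while for ES/OpenSearch each column becomes its own field with a type suffix from fields_map. The exact field names and suffixes below are assumptions based on the logic above, not captured output.

```python
# Illustrative chunk documents for the row ("E001", "Alice Johnson", "Engineering", "95000").

# Infinity: one JSON column; field_map keys are the original (lowercased) column names.
infinity_chunk = {
    "docnm_kwd": "employees.xlsx",
    "chunk_data": {"employee_id": "E001", "name": "Alice Johnson",
                   "department": "Engineering", "salary": 95000},
    # plus the tokenized "- Field: Value" text produced by tokenize()
}

# ES/OpenSearch: one field per column, suffixed according to the inferred type.
es_chunk = {
    "docnm_kwd": "employees.xlsx",
    "employee_id_kwd": "E001",      # keyword-like column -> _kwd
    "name_tks": "alice johnson",    # text column, tokenized -> _tks
    "department_kwd": "Engineering",
    "salary_long": 95000,           # integer column -> _long
}
```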
@@ -558,7 +558,8 @@ def build_TOC(task, docs, progress_callback):

def init_kb(row, vector_size: int):
idxnm = search.index_name(row["tenant_id"])
return settings.docStoreConn.create_idx(idxnm, row.get("kb_id", ""), vector_size)
parser_id = row.get("parser_id", None)
return settings.docStoreConn.create_idx(idxnm, row.get("kb_id", ""), vector_size, parser_id)


async def embedding(docs, mdl, parser_config=None, callback=None):
@@ -739,7 +740,7 @@ async def run_dataflow(task: dict):

start_ts = timer()
set_progress(task_id, prog=0.82, msg="[DOC Engine]:\nStart to index...")
e = await insert_es(task_id, task["tenant_id"], task["kb_id"], chunks, partial(set_progress, task_id, 0, 100000000))
e = await insert_chunks(task_id, task["tenant_id"], task["kb_id"], chunks, partial(set_progress, task_id, 0, 100000000))
if not e:
PipelineOperationLogService.create(document_id=doc_id, pipeline_id=dataflow_id,
task_type=PipelineTaskType.PARSE, dsl=str(pipeline))
@@ -833,7 +834,17 @@ async def delete_image(kb_id, chunk_id):
raise


async def insert_es(task_id, task_tenant_id, task_dataset_id, chunks, progress_callback):
async def insert_chunks(task_id, task_tenant_id, task_dataset_id, chunks, progress_callback):
"""
Insert chunks into document store (Elasticsearch OR Infinity).

Args:
task_id: Task identifier
task_tenant_id: Tenant ID
task_dataset_id: Dataset/knowledge base ID
chunks: List of chunk dictionaries to insert
progress_callback: Callback function for progress updates
"""
mothers = []
mother_ids = set([])
for ck in chunks:
@@ -858,7 +869,7 @@ async def insert_es(task_id, task_tenant_id, task_dataset_id, chunks, progress_c

for b in range(0, len(mothers), settings.DOC_BULK_SIZE):
await asyncio.to_thread(settings.docStoreConn.insert, mothers[b:b + settings.DOC_BULK_SIZE],
search.index_name(task_tenant_id), task_dataset_id, )
search.index_name(task_tenant_id), task_dataset_id)
task_canceled = has_canceled(task_id)
if task_canceled:
progress_callback(-1, msg="Task has been canceled.")
@@ -866,7 +877,7 @@ async def insert_es(task_id, task_tenant_id, task_dataset_id, chunks, progress_c

for b in range(0, len(chunks), settings.DOC_BULK_SIZE):
doc_store_result = await asyncio.to_thread(settings.docStoreConn.insert, chunks[b:b + settings.DOC_BULK_SIZE],
search.index_name(task_tenant_id), task_dataset_id, )
search.index_name(task_tenant_id), task_dataset_id)
task_canceled = has_canceled(task_id)
if task_canceled:
progress_callback(-1, msg="Task has been canceled.")
@@ -932,13 +943,6 @@ async def do_handle_task(task):
# prepare the progress callback function
progress_callback = partial(set_progress, task_id, task_from_page, task_to_page)

# FIXME: workaround, Infinity doesn't support table parsing method, this check is to notify user
lower_case_doc_engine = settings.DOC_ENGINE.lower()
if lower_case_doc_engine == 'infinity' and task['parser_id'].lower() == 'table':
error_message = "Table parsing method is not supported by Infinity, please use other parsing methods or use Elasticsearch as the document engine."
progress_callback(-1, msg=error_message)
raise Exception(error_message)

task_canceled = has_canceled(task_id)
if task_canceled:
progress_callback(-1, msg="Task has been canceled.")
@@ -1092,14 +1096,14 @@ async def do_handle_task(task):
chunk_count = len(set([chunk["id"] for chunk in chunks]))
start_ts = timer()

async def _maybe_insert_es(_chunks):
async def _maybe_insert_chunks(_chunks):
if has_canceled(task_id):
return True
insert_result = await insert_es(task_id, task_tenant_id, task_dataset_id, _chunks, progress_callback)
insert_result = await insert_chunks(task_id, task_tenant_id, task_dataset_id, _chunks, progress_callback)
return bool(insert_result)

try:
if not await _maybe_insert_es(chunks):
if not await _maybe_insert_chunks(chunks):
return

logging.info(
@@ -1115,7 +1119,7 @@ async def do_handle_task(task):
if toc_thread:
d = toc_thread.result()
if d:
if not await _maybe_insert_es([d]):
if not await _maybe_insert_chunks([d]):
return
DocumentService.increment_chunk_num(task_doc_id, task_dataset_id, 0, 1, 0)
@@ -317,7 +317,18 @@ class InfinityConnection(InfinityConnectionBase):
break
if vector_size == 0:
raise ValueError("Cannot infer vector size from documents")
self.create_idx(index_name, knowledgebase_id, vector_size)

# Determine parser_id from document structure
# Table parser documents have 'chunk_data' field
parser_id = None
if "chunk_data" in documents[0] and isinstance(documents[0].get("chunk_data"), dict):
from common.constants import ParserType
parser_id = ParserType.TABLE.value
self.logger.debug("Detected TABLE parser from document structure")

# Fallback: Create table with base schema (shouldn't normally happen as init_kb() creates it)
self.logger.debug(f"Fallback: Creating table {table_name} with base schema, parser_id: {parser_id}")
self.create_idx(index_name, knowledgebase_id, vector_size, parser_id)
table_instance = db_instance.get_table(table_name)

# embedding fields can't have a default value....
@@ -378,6 +389,12 @@ class InfinityConnection(InfinityConnectionBase):
d[k] = v
elif re.search(r"_feas$", k):
d[k] = json.dumps(v)
elif k == "chunk_data":
# Convert data dict to JSON string for storage
if isinstance(v, dict):
d[k] = json.dumps(v)
else:
d[k] = v
elif k == "kb_id":
if isinstance(d[k], list):
d[k] = d[k][0]  # since d[k] is a list, but we need a str
@@ -586,6 +603,9 @@ class InfinityConnection(InfinityConnectionBase):
res2[column] = res2[column].apply(lambda v: [kwd for kwd in v.split("###") if kwd])
elif re.search(r"_feas$", k):
res2[column] = res2[column].apply(lambda v: json.loads(v) if v else {})
elif k == "chunk_data":
# Parse JSON data back to dict for table parser fields
res2[column] = res2[column].apply(lambda v: json.loads(v) if v and isinstance(v, str) else v)
elif k == "position_int":
def to_position_int(v):
if v:
@@ -49,6 +49,11 @@ def update_dataset(auth, dataset_id, payload=None, *, headers=HEADERS, data=None


def delete_datasets(auth, payload=None, *, headers=HEADERS, data=None):
"""
Delete datasets.
The endpoint is DELETE /api/{VERSION}/datasets with payload {"ids": [...]}
This is the standard SDK REST API endpoint for dataset deletion.
"""
res = requests.delete(url=f"{HOST_ADDRESS}{DATASETS_API_URL}", headers=headers, auth=auth, json=payload, data=data)
return res.json()

@@ -300,12 +305,6 @@ def metadata_summary(auth, dataset_id, params=None):


# CHAT COMPLETIONS AND RELATED QUESTIONS
def chat_completions(auth, chat_assistant_id, payload=None):
url = f"{HOST_ADDRESS}{CHAT_ASSISTANT_API_URL}/{chat_assistant_id}/completions"
res = requests.post(url=url, headers=HEADERS, auth=auth, json=payload)
return res.json()


def related_questions(auth, payload=None):
url = f"{HOST_ADDRESS}/api/{VERSION}/sessions/related_questions"
res = requests.post(url=url, headers=HEADERS, auth=auth, json=payload)
@@ -355,3 +354,23 @@ def agent_completions(auth, agent_id, payload=None):
url = f"{HOST_ADDRESS}{AGENT_API_URL}/{agent_id}/completions"
res = requests.post(url=url, headers=HEADERS, auth=auth, json=payload)
return res.json()


def chat_completions(auth, chat_id, payload=None):
"""
Send a question/message to a chat assistant and get completion.

Args:
auth: Authentication object
chat_id: Chat assistant ID
payload: Dictionary containing:
- question: str (required) - The question to ask
- stream: bool (optional) - Whether to stream responses, default False
- session_id: str (optional) - Session ID for conversation context

Returns:
Response JSON with answer data
"""
url = f"{HOST_ADDRESS}/api/{VERSION}/chats/{chat_id}/completions"
res = requests.post(url=url, headers=HEADERS, auth=auth, json=payload)
return res.json()
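A minimal usage sketch of the helper above; the ids are placeholders and the response is assumed to follow the usual {"code": ..., "data": {"answer": ...}} shape returned by the chat completion endpoint.

```python
# Hypothetical call to chat_completions (ids are placeholders).
res = chat_completions(
    HttpApiAuth,
    "<chat_assistant_id>",
    {"question": "How many rows in the dataset?", "stream": False, "session_id": "<session_id>"},
)
assert res["code"] == 0
print(res["data"]["answer"])
```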
@@ -0,0 +1,42 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#


import pytest
from common import create_dataset, delete_datasets


@pytest.fixture(scope="class")
def add_table_parser_dataset(HttpApiAuth, request):
"""
Fixture to create a table parser dataset for testing.
Automatically cleans up after tests complete (deletes dataset and table).
Note: field_map is automatically generated by the table parser when processing files.
"""
dataset_payload = {
"name": "test_table_parser_dataset",
"chunk_method": "table",  # table parser
}
res = create_dataset(HttpApiAuth, dataset_payload)
assert res["code"] == 0, f"Failed to create dataset: {res}"
dataset_id = res["data"]["id"]

def cleanup():
delete_datasets(HttpApiAuth, {"ids": [dataset_id]})

request.addfinalizer(cleanup)

return dataset_id
@@ -0,0 +1,324 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import os
import re
import tempfile

import pytest

from common import (
chat_completions,
create_chat_assistant,
create_session_with_chat_assistant,
delete_chat_assistants,
list_documents,
upload_documents,
parse_documents,
)
from utils import wait_for
@wait_for(200, 1, "Document parsing timeout")
def wait_for_parsing_completion(auth, dataset_id, document_id=None):
"""
Wait for document parsing to complete.

Args:
auth: Authentication object
dataset_id: Dataset ID
document_id: Optional specific document ID to wait for

Returns:
bool: True if parsing is complete, False otherwise
"""
res = list_documents(auth, dataset_id)
docs = res["data"]["docs"]

if document_id is None:
# Wait for all documents to complete
for doc in docs:
status = doc.get("run", "UNKNOWN")
if status != "DONE":
print(f"[DEBUG] Document {doc.get('name', 'unknown')} status: {status}, progress: {doc.get('progress', 0)}%, msg: {doc.get('progress_msg', '')}")
return False
return True
else:
# Wait for specific document
for doc in docs:
if doc["id"] == document_id:
status = doc.get("run", "UNKNOWN")
print(f"[DEBUG] Document {doc.get('name', 'unknown')} status: {status}, progress: {doc.get('progress', 0)}%, msg: {doc.get('progress_msg', '')}")
if status == "DONE":
return True
elif status == "FAILED":
pytest.fail(f"Document parsing failed: {doc}")
return False
return False
# Test data
TEST_EXCEL_DATA = [
["employee_id", "name", "department", "salary"],
["E001", "Alice Johnson", "Engineering", "95000"],
["E002", "Bob Smith", "Marketing", "65000"],
["E003", "Carol Williams", "Engineering", "88000"],
["E004", "David Brown", "Sales", "72000"],
["E005", "Eva Davis", "HR", "68000"],
["E006", "Frank Miller", "Engineering", "102000"],
]

TEST_EXCEL_DATA_2 = [
["product", "price", "category"],
["Laptop", "999", "Electronics"],
["Mouse", "29", "Electronics"],
["Desk", "299", "Furniture"],
["Chair", "199", "Furniture"],
["Monitor", "399", "Electronics"],
["Keyboard", "79", "Electronics"],
]

DEFAULT_CHAT_PROMPT = (
"You are a helpful assistant that answers questions about table data using SQL queries.\n\n"
"Here is the knowledge base:\n{knowledge}\n\n"
"Use this information to answer questions."
)
@pytest.mark.usefixtures("add_table_parser_dataset")
|
||||
class TestTableParserDatasetChat:
|
||||
"""
|
||||
Test table parser dataset chat functionality with Infinity backend.
|
||||
|
||||
Verifies that:
|
||||
1. Excel files are uploaded and parsed correctly into table parser datasets
|
||||
2. Chat assistants can query the parsed table data via SQL
|
||||
3. Different types of queries work
|
||||
"""
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup_chat_assistant(self, HttpApiAuth, add_table_parser_dataset, request):
|
||||
"""
|
||||
Setup fixture that runs before each test method.
|
||||
Creates chat assistant once and reuses it across all test cases.
|
||||
"""
|
||||
# Only setup once (first time)
|
||||
if not hasattr(self.__class__, 'chat_id'):
|
||||
self.__class__.dataset_id = add_table_parser_dataset
|
||||
self.__class__.auth = HttpApiAuth
|
||||
|
||||
# Upload and parse Excel files once for all tests
|
||||
self._upload_and_parse_excel(HttpApiAuth, add_table_parser_dataset)
|
||||
|
||||
# Create a single chat assistant and session for all tests
|
||||
chat_id, session_id = self._create_chat_assistant_with_session(
|
||||
HttpApiAuth, add_table_parser_dataset
|
||||
)
|
||||
self.__class__.chat_id = chat_id
|
||||
self.__class__.session_id = session_id
|
||||
|
||||
# Store the total number of parametrize cases
|
||||
mark = request.node.get_closest_marker('parametrize')
|
||||
if mark:
|
||||
# Get the number of test cases from parametrize
|
||||
param_values = mark.args[1]
|
||||
self.__class__._total_tests = len(param_values)
|
||||
else:
|
||||
self.__class__._total_tests = 1
|
||||
|
||||
yield
|
||||
|
||||
# Teardown: cleanup chat assistant after all tests
|
||||
# Use a class-level counter to track tests
|
||||
if not hasattr(self.__class__, '_test_counter'):
|
||||
self.__class__._test_counter = 0
|
||||
self.__class__._test_counter += 1
|
||||
|
||||
# Cleanup after all parametrize tests complete
|
||||
if self.__class__._test_counter >= self.__class__._total_tests:
|
||||
self._teardown_chat_assistant()
|
||||
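
    # Design note: the suite caches one assistant and session on the class and counts finished
    # test invocations to decide when to tear them down, so all parametrized cases share a
    # single assistant. This assumes the parametrized test below is the only test method in
    # the class; adding more tests would require adjusting how _total_tests is computed.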

    def _teardown_chat_assistant(self):
        """Teardown method to clean up the chat assistant."""
        if hasattr(self.__class__, 'chat_id') and self.__class__.chat_id:
            try:
                delete_chat_assistants(self.__class__.auth, {"ids": [self.__class__.chat_id]})
            except Exception as e:
                print(f"[Teardown] Warning: Failed to delete chat assistant: {e}")

    @pytest.mark.p1
    @pytest.mark.parametrize(
        "question, expected_answer_pattern",
        [
            ("show me column of product", r"\|product\|Source"),
            ("which product has price 79", r"Keyboard"),
            ("How many rows in the dataset?", r"count\(\*\)"),
            ("Show me all employees in Engineering department", r"(Alice|Carol|Frank)"),
        ],
    )
    def test_table_parser_dataset_chat(self, question, expected_answer_pattern):
        """
        Test that table parser dataset chat works correctly.
        """
        # Use class-level attributes (set by the setup fixture)
        answer = self._ask_question(
            self.__class__.auth,
            self.__class__.chat_id,
            self.__class__.session_id,
            question
        )

        # Verify the answer matches the expected pattern, if one is provided
        if expected_answer_pattern:
            self._assert_answer_matches_pattern(answer, expected_answer_pattern)
        else:
            # Just verify we got a non-empty answer
            assert answer and len(answer) > 0, "Expected non-empty answer"

        print(f"[Test] Question: {question}")
        print(f"[Test] Answer: {answer[:100]}...")
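
    # A note on the expected_answer_pattern values above: they encode assumptions about how the
    # SQL-backed answer is rendered, namely a pipe-delimited table whose header includes a
    # "Source" column for the column-listing question, a literal count(*) column name for the
    # row-count question, and employee names surfaced verbatim for the department filter. If the
    # answer formatting changes, these regexes are the first thing to revisit.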

    @staticmethod
    def _upload_and_parse_excel(auth, dataset_id):
        """
        Upload 2 Excel files and wait for parsing to complete.

        Returns:
            list: The document IDs of the uploaded files

        Raises:
            AssertionError: If upload or parsing fails
        """
        excel_file_paths = []
        document_ids = []
        try:
            # Create 2 temporary Excel files
            excel_file_paths.append(TestTableParserDatasetChat._create_temp_excel_file(TEST_EXCEL_DATA))
            excel_file_paths.append(TestTableParserDatasetChat._create_temp_excel_file(TEST_EXCEL_DATA_2))

            # Upload documents
            res = upload_documents(auth, dataset_id, excel_file_paths)
            assert res["code"] == 0, f"Failed to upload documents: {res}"

            for doc in res["data"]:
                document_ids.append(doc["id"])

            # Start parsing for all documents
            parse_payload = {"document_ids": document_ids}
            res = parse_documents(auth, dataset_id, parse_payload)
            assert res["code"] == 0, f"Failed to start parsing: {res}"

            # Wait for parsing completion for all documents
            for doc_id in document_ids:
                wait_for_parsing_completion(auth, dataset_id, doc_id)

            return document_ids

        finally:
            # Clean up temporary files
            for excel_file_path in excel_file_paths:
                if excel_file_path:
                    os.unlink(excel_file_path)
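
    # The helper above mirrors the HTTP API flow end to end: upload the files, trigger parsing
    # for the returned document IDs, then poll each document until the table parser finishes.
    # If a document never reaches DONE, wait_for_parsing_completion is expected to time out via
    # its decorator rather than hang indefinitely.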

    @staticmethod
    def _create_temp_excel_file(data):
        """
        Create a temporary Excel file with the given table test data.

        Args:
            data: List of lists containing the Excel data

        Returns:
            str: Path to the created temporary file
        """
        from openpyxl import Workbook

        f = tempfile.NamedTemporaryFile(mode="wb", suffix=".xlsx", delete=False)
        f.close()

        wb = Workbook()
        ws = wb.active

        # Write test data to the worksheet
        for row_idx, row_data in enumerate(data, start=1):
            for col_idx, value in enumerate(row_data, start=1):
                ws.cell(row=row_idx, column=col_idx, value=value)

        wb.save(f.name)
        return f.name
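
    # openpyxl's ws.cell(row=..., column=...) uses 1-based indices, so enumerate(..., start=1)
    # writes the header into row 1 and the data rows directly beneath it; each TEST_EXCEL_DATA*
    # constant therefore yields a header row plus six data rows in the generated .xlsx file.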

    @staticmethod
    def _create_chat_assistant_with_session(auth, dataset_id):
        """
        Create a chat assistant and session for testing.

        Returns:
            tuple: (chat_id, session_id)
        """
        import uuid

        chat_payload = {
            "name": f"test_table_parser_dataset_chat_{uuid.uuid4().hex[:8]}",
            "dataset_ids": [dataset_id],
            "prompt_config": {
                "system": DEFAULT_CHAT_PROMPT,
                "parameters": [
                    {
                        "key": "knowledge",
                        "optional": True,
                        "value": "Use the table data to answer questions with SQL queries.",
                    }
                ],
            },
        }

        res = create_chat_assistant(auth, chat_payload)
        assert res["code"] == 0, f"Failed to create chat assistant: {res}"
        chat_id = res["data"]["id"]

        res = create_session_with_chat_assistant(auth, chat_id, {"name": f"test_session_{uuid.uuid4().hex[:8]}"})
        assert res["code"] == 0, f"Failed to create session: {res}"
        session_id = res["data"]["id"]

        return chat_id, session_id
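
    # Assumption worth flagging: marking the "knowledge" parameter as "optional": True is read
    # here as "the assistant may still answer when retrieval returns nothing", which keeps the
    # chat from erroring out before the table data is queryable. The authoritative semantics are
    # whatever the dialog service enforces for prompt_config parameters.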

    def _ask_question(self, auth, chat_id, session_id, question):
        """
        Send a question to the chat assistant and return the answer.

        Returns:
            str: The assistant's answer
        """
        payload = {
            "question": question,
            "stream": False,
            "session_id": session_id,
        }

        res_json = chat_completions(auth, chat_id, payload)
        assert res_json["code"] == 0, f"Chat completion failed: {res_json}"

        return res_json["data"]["answer"]
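
    # With "stream": False the completion endpoint is expected to return a single JSON object
    # whose data.answer field holds the full reply, rather than an event stream; that is why the
    # helper can simply index into res_json["data"]["answer"].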

    def _assert_answer_matches_pattern(self, answer, pattern):
        """
        Assert that the answer matches the expected pattern.

        Args:
            answer: The actual answer from the chat assistant
            pattern: Regular expression pattern to match
        """
        assert re.search(pattern, answer, re.IGNORECASE), (
            f"Answer does not match expected pattern '{pattern}'.\n"
            f"Answer: {answer}"
        )