From b40d639fdbf4ba65af684e1c68f4f013c023535d Mon Sep 17 00:00:00 2001 From: qinling0210 <88864212+qinling0210@users.noreply.github.com> Date: Mon, 19 Jan 2026 19:35:14 +0800 Subject: [PATCH] Add dataset with table parser type for Infinity and answer question in chat using SQL (#12541) ### What problem does this PR solve? 1) Create dataset using table parser for infinity 2) Answer questions in chat using SQL ### Type of change - [x] New Feature (non-breaking change which adds functionality) --- Dockerfile | 3 +- api/apps/dialog_app.py | 14 + api/apps/kb_app.py | 11 +- api/apps/sdk/dataset.py | 9 + api/db/services/dialog_service.py | 321 +++++++++++++---- api/db/services/document_service.py | 2 +- common/doc_store/doc_store_base.py | 2 +- common/doc_store/es_conn_base.py | 3 +- common/doc_store/infinity_conn_base.py | 209 ++++++++++- common/doc_store/infinity_conn_pool.py | 19 +- common/settings.py | 12 +- conf/service_conf.yaml | 1 + docker/service_conf.yaml.template | 1 + rag/app/table.py | 42 ++- rag/svr/task_executor.py | 36 +- rag/utils/infinity_conn.py | 22 +- test/testcases/test_http_api/common.py | 31 +- .../test_chat_management/conftest.py | 42 +++ .../test_table_parser_dataset_chat.py | 324 ++++++++++++++++++ 19 files changed, 1003 insertions(+), 101 deletions(-) create mode 100644 test/testcases/test_http_api/test_chat_management/conftest.py create mode 100644 test/testcases/test_http_api/test_chat_management/test_table_parser_dataset_chat.py diff --git a/Dockerfile b/Dockerfile index 47ef161e4..1da884343 100644 --- a/Dockerfile +++ b/Dockerfile @@ -53,7 +53,8 @@ RUN --mount=type=cache,id=ragflow_apt,target=/var/cache/apt,sharing=locked \ apt install -y ghostscript && \ apt install -y pandoc && \ apt install -y texlive && \ - apt install -y fonts-freefont-ttf fonts-noto-cjk + apt install -y fonts-freefont-ttf fonts-noto-cjk && \ + apt install -y postgresql-client # Install uv RUN --mount=type=bind,from=infiniflow/ragflow_deps:latest,source=/,target=/deps \ diff --git a/api/apps/dialog_app.py b/api/apps/dialog_app.py index 33502f402..9b7617797 100644 --- a/api/apps/dialog_app.py +++ b/api/apps/dialog_app.py @@ -25,6 +25,7 @@ from api.utils.api_utils import get_data_error_result, get_json_result, get_requ from common.misc_utils import get_uuid from common.constants import RetCode from api.apps import login_required, current_user +import logging @manager.route('/set', methods=['POST']) # noqa: F821 @@ -69,6 +70,19 @@ async def set_dialog(): meta_data_filter = req.get("meta_data_filter", {}) prompt_config = req["prompt_config"] + # Set default parameters for datasets with knowledge retrieval + # All datasets with {knowledge} in system prompt need "knowledge" parameter to enable retrieval + kb_ids = req.get("kb_ids", []) + parameters = prompt_config.get("parameters") + logging.debug(f"set_dialog: kb_ids={kb_ids}, parameters={parameters}, is_create={not is_create}") + # Check if parameters is missing, None, or empty list + if kb_ids and not parameters: + # Check if system prompt uses {knowledge} placeholder + if "{knowledge}" in prompt_config.get("system", ""): + # Set default parameters for any dataset with knowledge placeholder + prompt_config["parameters"] = [{"key": "knowledge", "optional": False}] + logging.debug(f"Set default parameters for datasets with knowledge placeholder: {kb_ids}") + if not is_create: # only for chat updating if not req.get("kb_ids", []) and not prompt_config.get("tavily_api_key") and "{knowledge}" in prompt_config.get("system", ""): diff --git 
a/api/apps/kb_app.py b/api/apps/kb_app.py index a35345feb..e7d86594d 100644 --- a/api/apps/kb_app.py +++ b/api/apps/kb_app.py @@ -295,12 +295,19 @@ async def rm(): File.name == kbs[0].name, ] ) + # Delete the table BEFORE deleting the database record + for kb in kbs: + try: + settings.docStoreConn.delete({"kb_id": kb.id}, search.index_name(kb.tenant_id), kb.id) + settings.docStoreConn.delete_idx(search.index_name(kb.tenant_id), kb.id) + logging.info(f"Dropped index for dataset {kb.id}") + except Exception as e: + logging.error(f"Failed to drop index for dataset {kb.id}: {e}") + if not KnowledgebaseService.delete_by_id(req["kb_id"]): return get_data_error_result( message="Database error (Knowledgebase removal)!") for kb in kbs: - settings.docStoreConn.delete({"kb_id": kb.id}, search.index_name(kb.tenant_id), kb.id) - settings.docStoreConn.delete_idx(search.index_name(kb.tenant_id), kb.id) if hasattr(settings.STORAGE_IMPL, 'remove_bucket'): settings.STORAGE_IMPL.remove_bucket(kb.id) return get_json_result(data=True) diff --git a/api/apps/sdk/dataset.py b/api/apps/sdk/dataset.py index f98705de0..d0d7ff0c6 100644 --- a/api/apps/sdk/dataset.py +++ b/api/apps/sdk/dataset.py @@ -233,6 +233,15 @@ async def delete(tenant_id): File2DocumentService.delete_by_document_id(doc.id) FileService.filter_delete( [File.source_type == FileSource.KNOWLEDGEBASE, File.type == "folder", File.name == kb.name]) + + # Drop index for this dataset + try: + from rag.nlp import search + idxnm = search.index_name(kb.tenant_id) + settings.docStoreConn.delete_idx(idxnm, kb_id) + except Exception as e: + logging.warning(f"Failed to drop index for dataset {kb_id}: {e}") + if not KnowledgebaseService.delete_by_id(kb_id): errors.append(f"Delete dataset error for {kb_id}") continue diff --git a/api/db/services/dialog_service.py b/api/db/services/dialog_service.py index ccf8474b6..707227653 100644 --- a/api/db/services/dialog_service.py +++ b/api/db/services/dialog_service.py @@ -37,7 +37,6 @@ from api.db.services.tenant_llm_service import TenantLLMService from common.time_utils import current_timestamp, datetime_format from graphrag.general.mind_map_extractor import MindMapExtractor from rag.advanced_rag import DeepResearcher -from rag.app.resume import forbidden_select_fields4resume from rag.app.tag import label_question from rag.nlp.search import index_name from rag.prompts.generator import chunks_format, citation_prompt, cross_languages, full_question, kb_prompt, keyword_extraction, message_fit_in, \ @@ -274,6 +273,7 @@ def repair_bad_citation_formats(answer: str, kbinfos: dict, idx: set): async def async_chat(dialog, messages, stream=True, **kwargs): + logging.debug("Begin async_chat") assert messages[-1]["role"] == "user", "The last content of this conversation is not from user." 
if not dialog.kb_ids and not dialog.prompt_config.get("tavily_api_key"): async for ans in async_chat_solo(dialog, messages, stream): @@ -323,13 +323,20 @@ async def async_chat(dialog, messages, stream=True, **kwargs): prompt_config = dialog.prompt_config field_map = KnowledgebaseService.get_field_map(dialog.kb_ids) + logging.debug(f"field_map retrieved: {field_map}") # try to use sql if field mapping is good to go if field_map: logging.debug("Use SQL to retrieval:{}".format(questions[-1])) ans = await use_sql(questions[-1], field_map, dialog.tenant_id, chat_mdl, prompt_config.get("quote", True), dialog.kb_ids) - if ans: + # For aggregate queries (COUNT, SUM, etc.), chunks may be empty but answer is still valid + if ans and (ans.get("reference", {}).get("chunks") or ans.get("answer")): yield ans return + else: + logging.debug("SQL failed or returned no results, falling back to vector search") + + param_keys = [p["key"] for p in prompt_config.get("parameters", [])] + logging.debug(f"attachments={attachments}, param_keys={param_keys}, embd_mdl={embd_mdl}") for p in prompt_config["parameters"]: if p["key"] == "knowledge": @@ -366,7 +373,8 @@ async def async_chat(dialog, messages, stream=True, **kwargs): kbinfos = {"total": 0, "chunks": [], "doc_aggs": []} knowledges = [] - if attachments is not None and "knowledge" in [p["key"] for p in prompt_config["parameters"]]: + if attachments is not None and "knowledge" in param_keys: + logging.debug("Proceeding with retrieval") tenant_ids = list(set([kb.tenant_id for kb in kbs])) knowledges = [] if prompt_config.get("reasoning", False): @@ -575,112 +583,306 @@ async def async_chat(dialog, messages, stream=True, **kwargs): async def use_sql(question, field_map, tenant_id, chat_mdl, quota=True, kb_ids=None): - sys_prompt = """ -You are a Database Administrator. You need to check the fields of the following tables based on the user's list of questions and write the SQL corresponding to the last question. -Ensure that: -1. Field names should not start with a digit. If any field name starts with a digit, use double quotes around it. -2. Write only the SQL, no explanations or additional text. -""" - user_prompt = """ -Table name: {}; -Table of database fields are as follows: -{} + logging.debug(f"use_sql: Question: {question}") -Question are as follows: + # Determine which document engine we're using + doc_engine = "infinity" if settings.DOC_ENGINE_INFINITY else "es" + + # Construct the full table name + # For Elasticsearch: ragflow_{tenant_id} (kb_id is in WHERE clause) + # For Infinity: ragflow_{tenant_id}_{kb_id} (each KB has its own table) + base_table = index_name(tenant_id) + if doc_engine == "infinity" and kb_ids and len(kb_ids) == 1: + # Infinity: append kb_id to table name + table_name = f"{base_table}_{kb_ids[0]}" + logging.debug(f"use_sql: Using Infinity table name: {table_name}") + else: + # Elasticsearch/OpenSearch: use base index name + table_name = base_table + logging.debug(f"use_sql: Using ES/OS table name: {table_name}") + + # Generate engine-specific SQL prompts + if doc_engine == "infinity": + # Build Infinity prompts with JSON extraction context + json_field_names = list(field_map.keys()) + sys_prompt = """You are a Database Administrator. Write SQL for a table with JSON 'chunk_data' column. + +JSON Extraction: json_extract_string(chunk_data, '$.FieldName') +Numeric Cast: CAST(json_extract_string(chunk_data, '$.FieldName') AS INTEGER/FLOAT) +NULL Check: json_extract_isnull(chunk_data, '$.FieldName') == false + +RULES: +1. 
Use EXACT field names (case-sensitive) from the list below +2. For SELECT: include doc_id, docnm, and json_extract_string() for requested fields +3. For COUNT: use COUNT(*) or COUNT(DISTINCT json_extract_string(...)) +4. Add AS alias for extracted field names +5. DO NOT select 'content' field +6. Only add NULL check (json_extract_isnull() == false) in WHERE clause when: + - Question asks to "show me" or "display" specific columns + - Question mentions "not null" or "excluding null" + - Add a NULL check when counting a specific column + - DO NOT add NULL check for COUNT(*) queries (COUNT(*) counts all rows including nulls) +7. Output ONLY the SQL, no explanations""" + user_prompt = """Table: {} +Fields (EXACT case): {} {} -Please write the SQL, only SQL, without any other explanations or text. -""".format(index_name(tenant_id), "\n".join([f"{k}: {v}" for k, v in field_map.items()]), question) +Question: {} +Write SQL using json_extract_string() with exact field names. Include doc_id, docnm for data queries. Only SQL.""".format( + table_name, + ", ".join(json_field_names), + "\n".join([f" - {field}" for field in json_field_names]), + question + ) + else: + # Build ES/OS prompts with direct field access + sys_prompt = """You are a Database Administrator. Write SQL queries. + +RULES: +1. Use EXACT field names from the schema below (e.g., product_tks, not product) +2. Quote field names starting with digit: "123_field" +3. Add IS NOT NULL in WHERE clause when: + - Question asks to "show me" or "display" specific columns +4. Include doc_id/docnm in non-aggregate statements +5. Output ONLY the SQL, no explanations""" + user_prompt = """Table: {} +Available fields: +{} +Question: {} +Write SQL using exact field names above. Include doc_id, docnm_kwd for data queries. Only SQL.""".format( + table_name, + "\n".join([f" - {k} ({v})" for k, v in field_map.items()]), + question + ) + tried_times = 0 async def get_table(): nonlocal sys_prompt, user_prompt, question, tried_times sql = await chat_mdl.async_chat(sys_prompt, [{"role": "user", "content": user_prompt}], {"temperature": 0.06}) - sql = re.sub(r"^.*</think>", "", sql, flags=re.DOTALL) - logging.debug(f"{question} ==> {user_prompt} get SQL: {sql}") - sql = re.sub(r"[\r\n]+", " ", sql.lower()) - sql = re.sub(r".*select ", "select ", sql.lower()) - sql = re.sub(r" +", " ", sql) - sql = re.sub(r"([;；]|```).*", "", sql) - sql = re.sub(r"&", "and", sql) - if sql[: len("select ")] != "select ": - return None, None - if not re.search(r"((sum|avg|max|min)\(|group by )", sql.lower()): - if sql[: len("select *")] != "select *": - sql = "select doc_id,docnm_kwd," + sql[6:] - else: - flds = [] - for k in field_map.keys(): - if k in forbidden_select_fields4resume: - continue - if len(flds) > 11: - break - flds.append(k) - sql = "select doc_id,docnm_kwd," + ",".join(flds) + sql[8:] + logging.debug(f"use_sql: Raw SQL from LLM: {repr(sql[:500])}") + # Remove think blocks if present (format: <think>...</think>) + sql = re.sub(r"<think>\n.*?\n</think>\s*", "", sql, flags=re.DOTALL) + sql = re.sub(r"思考\n.*?\n", "", sql, flags=re.DOTALL) + # Remove markdown code blocks (```sql ... 
```) + sql = re.sub(r"```(?:sql)?\s*", "", sql, flags=re.IGNORECASE) + sql = re.sub(r"```\s*$", "", sql, flags=re.IGNORECASE) + # Remove trailing semicolon that ES SQL parser doesn't like + sql = sql.rstrip().rstrip(';').strip() - if kb_ids: - kb_filter = "(" + " OR ".join([f"kb_id = '{kb_id}'" for kb_id in kb_ids]) + ")" - if "where" not in sql.lower(): + # Add kb_id filter for ES/OS only (Infinity already has it in table name) + if doc_engine != "infinity" and kb_ids: + # Build kb_filter: single KB or multiple KBs with OR + if len(kb_ids) == 1: + kb_filter = f"kb_id = '{kb_ids[0]}'" + else: + kb_filter = "(" + " OR ".join([f"kb_id = '{kb_id}'" for kb_id in kb_ids]) + ")" + + if "where " not in sql.lower(): o = sql.lower().split("order by") if len(o) > 1: sql = o[0] + f" WHERE {kb_filter} order by " + o[1] else: sql += f" WHERE {kb_filter}" - else: - sql += f" AND {kb_filter}" + elif "kb_id =" not in sql.lower() and "kb_id=" not in sql.lower(): + sql = re.sub(r"\bwhere\b ", f"where {kb_filter} and ", sql, flags=re.IGNORECASE) logging.debug(f"{question} get SQL(refined): {sql}") tried_times += 1 - return settings.retriever.sql_retrieval(sql, format="json"), sql + logging.debug(f"use_sql: Executing SQL retrieval (attempt {tried_times})") + tbl = settings.retriever.sql_retrieval(sql, format="json") + if tbl is None: + logging.debug("use_sql: SQL retrieval returned None") + return None, sql + logging.debug(f"use_sql: SQL retrieval completed, got {len(tbl.get('rows', []))} rows") + return tbl, sql try: tbl, sql = await get_table() + logging.debug(f"use_sql: Initial SQL execution SUCCESS. SQL: {sql}") + logging.debug(f"use_sql: Retrieved {len(tbl.get('rows', []))} rows, columns: {[c['name'] for c in tbl.get('columns', [])]}") except Exception as e: - user_prompt = """ + logging.warning(f"use_sql: Initial SQL execution FAILED with error: {e}") + # Build retry prompt with error information + if doc_engine == "infinity": + # Build Infinity error retry prompt + json_field_names = list(field_map.keys()) + user_prompt = """ +Table name: {}; +JSON fields available in 'chunk_data' column (use these exact names in json_extract_string): +{} + +Question: {} +Please write the SQL using json_extract_string(chunk_data, '$.field_name') with the field names from the list above. Only SQL, no explanations. + + +The SQL error you provided last time is as follows: +{} + +Please correct the error and write SQL again using json_extract_string(chunk_data, '$.field_name') syntax with the correct field names. Only SQL, no explanations. +""".format(table_name, "\n".join([f" - {field}" for field in json_field_names]), question, e) + else: + # Build ES/OS error retry prompt + user_prompt = """ Table name: {}; - Table of database fields are as follows: + Table of database fields are as follows (use the field names directly in SQL): {} Question are as follows: {} - Please write the SQL, only SQL, without any other explanations or text. + Please write the SQL using the exact field names above, only SQL, without any other explanations or text. The SQL error you provided last time is as follows: {} - Please correct the error and write SQL again, only SQL, without any other explanations or text. - """.format(index_name(tenant_id), "\n".join([f"{k}: {v}" for k, v in field_map.items()]), question, e) + Please correct the error and write SQL again using the exact field names above, only SQL, without any other explanations or text. 
+ """.format(table_name, "\n".join([f"{k} ({v})" for k, v in field_map.items()]), question, e) try: tbl, sql = await get_table() + logging.debug(f"use_sql: Retry SQL execution SUCCESS. SQL: {sql}") + logging.debug(f"use_sql: Retrieved {len(tbl.get('rows', []))} rows on retry") except Exception: + logging.error("use_sql: Retry SQL execution also FAILED, returning None") return if len(tbl["rows"]) == 0: + logging.warning(f"use_sql: No rows returned from SQL query, returning None. SQL: {sql}") return None - docid_idx = set([ii for ii, c in enumerate(tbl["columns"]) if c["name"] == "doc_id"]) - doc_name_idx = set([ii for ii, c in enumerate(tbl["columns"]) if c["name"] == "docnm_kwd"]) + logging.debug(f"use_sql: Proceeding with {len(tbl['rows'])} rows to build answer") + + docid_idx = set([ii for ii, c in enumerate(tbl["columns"]) if c["name"].lower() == "doc_id"]) + doc_name_idx = set([ii for ii, c in enumerate(tbl["columns"]) if c["name"].lower() in ["docnm_kwd", "docnm"]]) + + logging.debug(f"use_sql: All columns: {[(i, c['name']) for i, c in enumerate(tbl['columns'])]}") + logging.debug(f"use_sql: docid_idx={docid_idx}, doc_name_idx={doc_name_idx}") + column_idx = [ii for ii in range(len(tbl["columns"])) if ii not in (docid_idx | doc_name_idx)] + logging.debug(f"use_sql: column_idx={column_idx}") + logging.debug(f"use_sql: field_map={field_map}") + + # Helper function to map column names to display names + def map_column_name(col_name): + if col_name.lower() == "count(star)": + return "COUNT(*)" + + # First, try to extract AS alias from any expression (aggregate functions, json_extract_string, etc.) + # Pattern: anything AS alias_name + as_match = re.search(r'\s+AS\s+([^\s,)]+)', col_name, re.IGNORECASE) + if as_match: + alias = as_match.group(1).strip('"\'') + + # Use the alias for display name lookup + if alias in field_map: + display = field_map[alias] + return re.sub(r"(/.*|([^()]+))", "", display) + # If alias not in field_map, try to match case-insensitively + for field_key, display_value in field_map.items(): + if field_key.lower() == alias.lower(): + return re.sub(r"(/.*|([^()]+))", "", display_value) + # Return alias as-is if no mapping found + return alias + + # Try direct mapping first (for simple column names) + if col_name in field_map: + display = field_map[col_name] + # Clean up any suffix patterns + return re.sub(r"(/.*|([^()]+))", "", display) + + # Try case-insensitive match for simple column names + col_lower = col_name.lower() + for field_key, display_value in field_map.items(): + if field_key.lower() == col_lower: + return re.sub(r"(/.*|([^()]+))", "", display_value) + + # For aggregate expressions or complex expressions without AS alias, + # try to replace field names with display names + result = col_name + for field_name, display_name in field_map.items(): + # Replace field_name with display_name in the expression + result = result.replace(field_name, display_name) + + # Clean up any suffix patterns + result = re.sub(r"(/.*|([^()]+))", "", result) + return result + # compose Markdown table columns = ( "|" + "|".join( - [re.sub(r"(/.*|([^()]+))", "", field_map.get(tbl["columns"][i]["name"], tbl["columns"][i]["name"])) for i in column_idx]) + ( - "|Source|" if docid_idx and docid_idx else "|") + [map_column_name(tbl["columns"][i]["name"]) for i in column_idx]) + ( + "|Source|" if docid_idx and doc_name_idx else "|") ) line = "|" + "|".join(["------" for _ in range(len(column_idx))]) + ("|------|" if docid_idx and docid_idx else "") - rows = ["|" + 
"|".join([remove_redundant_spaces(str(r[i])) for i in column_idx]).replace("None", " ") + "|" for r in tbl["rows"]] - rows = [r for r in rows if re.sub(r"[ |]+", "", r)] + # Build rows ensuring column names match values - create a dict for each row + # keyed by column name to handle any SQL column order + rows = [] + for row_idx, r in enumerate(tbl["rows"]): + row_dict = {tbl["columns"][i]["name"]: r[i] for i in range(len(tbl["columns"])) if i < len(r)} + if row_idx == 0: + logging.debug(f"use_sql: First row data: {row_dict}") + row_values = [] + for col_idx in column_idx: + col_name = tbl["columns"][col_idx]["name"] + value = row_dict.get(col_name, " ") + row_values.append(remove_redundant_spaces(str(value)).replace("None", " ")) + # Add Source column with citation marker if Source column exists + if docid_idx and doc_name_idx: + row_values.append(f" ##{row_idx}$$") + row_str = "|" + "|".join(row_values) + "|" + if re.sub(r"[ |]+", "", row_str): + rows.append(row_str) if quota: - rows = "\n".join([r + f" ##{ii}$$ |" for ii, r in enumerate(rows)]) + rows = "\n".join(rows) else: - rows = "\n".join([r + f" ##{ii}$$ |" for ii, r in enumerate(rows)]) + rows = "\n".join(rows) rows = re.sub(r"T[0-9]{2}:[0-9]{2}:[0-9]{2}(\.[0-9]+Z)?\|", "|", rows) if not docid_idx or not doc_name_idx: - logging.warning("SQL missing field: " + sql) + logging.warning(f"use_sql: SQL missing required doc_id or docnm_kwd field. docid_idx={docid_idx}, doc_name_idx={doc_name_idx}. SQL: {sql}") + # For aggregate queries (COUNT, SUM, AVG, MAX, MIN, DISTINCT), fetch doc_id, docnm_kwd separately + # to provide source chunks, but keep the original table format answer + if re.search(r"(count|sum|avg|max|min|distinct)\s*\(", sql.lower()): + # Keep original table format as answer + answer = "\n".join([columns, line, rows]) + + # Now fetch doc_id, docnm_kwd to provide source chunks + # Extract WHERE clause from the original SQL + where_match = re.search(r"\bwhere\b(.+?)(?:\bgroup by\b|\border by\b|\blimit\b|$)", sql, re.IGNORECASE) + if where_match: + where_clause = where_match.group(1).strip() + # Build a query to get doc_id and docnm_kwd with the same WHERE clause + chunks_sql = f"select doc_id, docnm_kwd from {table_name} where {where_clause}" + # Add LIMIT to avoid fetching too many chunks + if "limit" not in chunks_sql.lower(): + chunks_sql += " limit 20" + logging.debug(f"use_sql: Fetching chunks with SQL: {chunks_sql}") + try: + chunks_tbl = settings.retriever.sql_retrieval(chunks_sql, format="json") + if chunks_tbl.get("rows") and len(chunks_tbl["rows"]) > 0: + # Build chunks reference - use case-insensitive matching + chunks_did_idx = next((i for i, c in enumerate(chunks_tbl["columns"]) if c["name"].lower() == "doc_id"), None) + chunks_dn_idx = next((i for i, c in enumerate(chunks_tbl["columns"]) if c["name"].lower() in ["docnm_kwd", "docnm"]), None) + if chunks_did_idx is not None and chunks_dn_idx is not None: + chunks = [{"doc_id": r[chunks_did_idx], "docnm_kwd": r[chunks_dn_idx]} for r in chunks_tbl["rows"]] + # Build doc_aggs + doc_aggs = {} + for r in chunks_tbl["rows"]: + doc_id = r[chunks_did_idx] + doc_name = r[chunks_dn_idx] + if doc_id not in doc_aggs: + doc_aggs[doc_id] = {"doc_name": doc_name, "count": 0} + doc_aggs[doc_id]["count"] += 1 + doc_aggs_list = [{"doc_id": did, "doc_name": d["doc_name"], "count": d["count"]} for did, d in doc_aggs.items()] + logging.debug(f"use_sql: Returning aggregate answer with {len(chunks)} chunks from {len(doc_aggs)} documents") + return {"answer": answer, "reference": 
{"chunks": chunks, "doc_aggs": doc_aggs_list}, "prompt": sys_prompt} + except Exception as e: + logging.warning(f"use_sql: Failed to fetch chunks: {e}") + # Fallback: return answer without chunks + return {"answer": answer, "reference": {"chunks": [], "doc_aggs": []}, "prompt": sys_prompt} + # Fallback to table format for other cases return {"answer": "\n".join([columns, line, rows]), "reference": {"chunks": [], "doc_aggs": []}, "prompt": sys_prompt} docid_idx = list(docid_idx)[0] @@ -690,7 +892,8 @@ Please write the SQL, only SQL, without any other explanations or text. if r[docid_idx] not in doc_aggs: doc_aggs[r[docid_idx]] = {"doc_name": r[doc_name_idx], "count": 0} doc_aggs[r[docid_idx]]["count"] += 1 - return { + + result = { "answer": "\n".join([columns, line, rows]), "reference": { "chunks": [{"doc_id": r[docid_idx], "docnm_kwd": r[doc_name_idx]} for r in tbl["rows"]], @@ -698,6 +901,8 @@ Please write the SQL, only SQL, without any other explanations or text. }, "prompt": sys_prompt, } + logging.debug(f"use_sql: Returning answer with {len(result['reference']['chunks'])} chunks from {len(doc_aggs)} documents") + return result def clean_tts_text(text: str) -> str: if not text: diff --git a/api/db/services/document_service.py b/api/db/services/document_service.py index ef1b831aa..896d97c77 100644 --- a/api/db/services/document_service.py +++ b/api/db/services/document_service.py @@ -1279,7 +1279,7 @@ def doc_upload_and_parse(conversation_id, file_objs, user_id): for b in range(0, len(cks), es_bulk_size): if try_create_idx: if not settings.docStoreConn.index_exist(idxnm, kb_id): - settings.docStoreConn.create_idx(idxnm, kb_id, len(vectors[0])) + settings.docStoreConn.create_idx(idxnm, kb_id, len(vectors[0]), kb.parser_id) try_create_idx = False settings.docStoreConn.insert(cks[b:b + es_bulk_size], idxnm, kb_id) diff --git a/common/doc_store/doc_store_base.py b/common/doc_store/doc_store_base.py index fe6304f75..fd684baef 100644 --- a/common/doc_store/doc_store_base.py +++ b/common/doc_store/doc_store_base.py @@ -164,7 +164,7 @@ class DocStoreConnection(ABC): """ @abstractmethod - def create_idx(self, index_name: str, dataset_id: str, vector_size: int): + def create_idx(self, index_name: str, dataset_id: str, vector_size: int, parser_id: str = None): """ Create an index with given name """ diff --git a/common/doc_store/es_conn_base.py b/common/doc_store/es_conn_base.py index cec628c0d..3bbd8f7ca 100644 --- a/common/doc_store/es_conn_base.py +++ b/common/doc_store/es_conn_base.py @@ -123,7 +123,8 @@ class ESConnectionBase(DocStoreConnection): Table operations """ - def create_idx(self, index_name: str, dataset_id: str, vector_size: int): + def create_idx(self, index_name: str, dataset_id: str, vector_size: int, parser_id: str = None): + # parser_id is used by Infinity but not needed for ES (kept for interface compatibility) if self.index_exist(index_name, dataset_id): return True try: diff --git a/common/doc_store/infinity_conn_base.py b/common/doc_store/infinity_conn_base.py index c8679c31c..218f12552 100644 --- a/common/doc_store/infinity_conn_base.py +++ b/common/doc_store/infinity_conn_base.py @@ -228,15 +228,26 @@ class InfinityConnectionBase(DocStoreConnection): Table operations """ - def create_idx(self, index_name: str, dataset_id: str, vector_size: int): + def create_idx(self, index_name: str, dataset_id: str, vector_size: int, parser_id: str = None): table_name = f"{index_name}_{dataset_id}" + self.logger.debug(f"CREATE_IDX: Creating table {table_name}, parser_id: {parser_id}") 
+ inf_conn = self.connPool.get_conn() inf_db = inf_conn.create_database(self.dbName, ConflictType.Ignore) + # Use configured schema fp_mapping = os.path.join(get_project_base_directory(), "conf", self.mapping_file_name) if not os.path.exists(fp_mapping): raise Exception(f"Mapping file not found at {fp_mapping}") schema = json.load(open(fp_mapping)) + + if parser_id is not None: + from common.constants import ParserType + if parser_id == ParserType.TABLE.value: + # Table parser: add chunk_data JSON column to store table-specific fields + schema["chunk_data"] = {"type": "json", "default": "{}"} + self.logger.info("Added chunk_data column for TABLE parser") + vector_name = f"q_{vector_size}_vec" schema[vector_name] = {"type": f"vector,{vector_size},float"} inf_table = inf_db.create_table( @@ -453,4 +464,198 @@ class InfinityConnectionBase(DocStoreConnection): """ def sql(self, sql: str, fetch_size: int, format: str): - raise NotImplementedError("Not implemented") + """ + Execute SQL query on Infinity database via psql command. + Transform text-to-sql for Infinity's SQL syntax. + """ + import subprocess + + try: + self.logger.debug(f"InfinityConnection.sql get sql: {sql}") + + # Clean up SQL + sql = re.sub(r"[ `]+", " ", sql) + sql = sql.replace("%", "") + + # Transform SELECT field aliases to actual stored field names + # Build field mapping from infinity_mapping.json comment field + field_mapping = {} + # Also build reverse mapping for column names in result + reverse_mapping = {} + fp_mapping = os.path.join(get_project_base_directory(), "conf", self.mapping_file_name) + if os.path.exists(fp_mapping): + schema = json.load(open(fp_mapping)) + for field_name, field_info in schema.items(): + if "comment" in field_info: + # Parse comma-separated aliases from comment + # e.g., "docnm_kwd, title_tks, title_sm_tks" + aliases = [a.strip() for a in field_info["comment"].split(",")] + for alias in aliases: + field_mapping[alias] = field_name + reverse_mapping[field_name] = alias # Store first alias for reverse mapping + + # Replace field names in SELECT clause + select_match = re.search(r"(select\s+.*?)(from\s+)", sql, re.IGNORECASE) + if select_match: + select_clause = select_match.group(1) + from_clause = select_match.group(2) + + # Apply field transformations + for alias, actual in field_mapping.items(): + select_clause = re.sub( + rf'(^|[, ]){alias}([, ]|$)', + rf'\1{actual}\2', + select_clause + ) + + sql = select_clause + from_clause + sql[select_match.end():] + + # Also replace field names in WHERE, ORDER BY, GROUP BY, and HAVING clauses + for alias, actual in field_mapping.items(): + # Transform in WHERE clause + sql = re.sub( + rf'(\bwhere\s+[^;]*?)(\b){re.escape(alias)}\b', + rf'\1{actual}', + sql, + flags=re.IGNORECASE + ) + # Transform in ORDER BY clause + sql = re.sub( + rf'(\border by\s+[^;]*?)(\b){re.escape(alias)}\b', + rf'\1{actual}', + sql, + flags=re.IGNORECASE + ) + # Transform in GROUP BY clause + sql = re.sub( + rf'(\bgroup by\s+[^;]*?)(\b){re.escape(alias)}\b', + rf'\1{actual}', + sql, + flags=re.IGNORECASE + ) + # Transform in HAVING clause + sql = re.sub( + rf'(\bhaving\s+[^;]*?)(\b){re.escape(alias)}\b', + rf'\1{actual}', + sql, + flags=re.IGNORECASE + ) + + self.logger.debug(f"InfinityConnection.sql to execute: {sql}") + + # Get connection parameters from the Infinity connection pool wrapper + # We need to use INFINITY_CONN singleton, not the raw ConnectionPool + from common.doc_store.infinity_conn_pool import INFINITY_CONN + conn_info = INFINITY_CONN.get_conn_uri() + + # 
Parse host and port from conn_info + if conn_info and "host=" in conn_info: + host_match = re.search(r"host=(\S+)", conn_info) + if host_match: + host = host_match.group(1) + else: + host = "infinity" + else: + host = "infinity" + + # Parse port from conn_info, default to 5432 if not found + if conn_info and "port=" in conn_info: + port_match = re.search(r"port=(\d+)", conn_info) + if port_match: + port = port_match.group(1) + else: + port = "5432" + else: + port = "5432" + + # Use psql command to execute SQL + # Use full path to psql to avoid PATH issues + psql_path = "/usr/bin/psql" + # Check if psql exists at expected location, otherwise try to find it + import shutil + psql_from_path = shutil.which("psql") + if psql_from_path: + psql_path = psql_from_path + + # Execute SQL with psql to get both column names and data in one call + psql_cmd = [ + psql_path, + "-h", host, + "-p", port, + "-c", sql, + ] + + self.logger.debug(f"Executing psql command: {' '.join(psql_cmd)}") + + result = subprocess.run( + psql_cmd, + capture_output=True, + text=True, + timeout=10 # 10 second timeout + ) + + if result.returncode != 0: + error_msg = result.stderr.strip() + raise Exception(f"psql command failed: {error_msg}\nSQL: {sql}") + + # Parse the output + output = result.stdout.strip() + if not output: + # No results + return { + "columns": [], + "rows": [] + } if format == "json" else [] + + # Parse psql table output which has format: + # col1 | col2 | col3 + # -----+-----+----- + # val1 | val2 | val3 + lines = output.split("\n") + + # Extract column names from first line + columns = [] + rows = [] + + if len(lines) >= 1: + header_line = lines[0] + for col_name in header_line.split("|"): + col_name = col_name.strip() + if col_name: + columns.append({"name": col_name}) + + # Data starts after the separator line (line with dashes) + data_start = 2 if len(lines) >= 2 and "-" in lines[1] else 1 + for i in range(data_start, len(lines)): + line = lines[i].strip() + # Skip empty lines and footer lines like "(1 row)" + if not line or re.match(r"^\(\d+ row", line): + continue + # Split by | and strip each cell + row = [cell.strip() for cell in line.split("|")] + # Ensure row matches column count + if len(row) == len(columns): + rows.append(row) + elif len(row) > len(columns): + # Row has more cells than columns - truncate + rows.append(row[:len(columns)]) + elif len(row) < len(columns): + # Row has fewer cells - pad with empty strings + rows.append(row + [""] * (len(columns) - len(row))) + + if format == "json": + result = { + "columns": columns, + "rows": rows[:fetch_size] if fetch_size > 0 else rows + } + else: + result = rows[:fetch_size] if fetch_size > 0 else rows + + return result + + except subprocess.TimeoutExpired: + self.logger.exception(f"InfinityConnection.sql timeout. SQL:\n{sql}") + raise Exception(f"SQL timeout\n\nSQL: {sql}") + except Exception as e: + self.logger.exception(f"InfinityConnection.sql got exception. 
SQL:\n{sql}") + raise Exception(f"SQL error: {e}\n\nSQL: {sql}") diff --git a/common/doc_store/infinity_conn_pool.py b/common/doc_store/infinity_conn_pool.py index f74e24409..1aa3f8125 100644 --- a/common/doc_store/infinity_conn_pool.py +++ b/common/doc_store/infinity_conn_pool.py @@ -31,7 +31,11 @@ class InfinityConnectionPool: if hasattr(settings, "INFINITY"): self.INFINITY_CONFIG = settings.INFINITY else: - self.INFINITY_CONFIG = settings.get_base_config("infinity", {"uri": "infinity:23817"}) + self.INFINITY_CONFIG = settings.get_base_config("infinity", { + "uri": "infinity:23817", + "postgres_port": 5432, + "db_name": "default_db" + }) infinity_uri = self.INFINITY_CONFIG["uri"] if ":" in infinity_uri: @@ -61,6 +65,19 @@ class InfinityConnectionPool: def get_conn_pool(self): return self.conn_pool + def get_conn_uri(self): + """ + Get connection URI for PostgreSQL protocol. + """ + infinity_uri = self.INFINITY_CONFIG["uri"] + postgres_port = self.INFINITY_CONFIG["postgres_port"] + db_name = self.INFINITY_CONFIG["db_name"] + + if ":" in infinity_uri: + host, _ = infinity_uri.split(":") + return f"host={host} port={postgres_port} dbname={db_name}" + return f"host=localhost port={postgres_port} dbname={db_name}" + def refresh_conn_pool(self): try: inf_conn = self.conn_pool.get_conn() diff --git a/common/settings.py b/common/settings.py index efdd1fe36..83415c680 100644 --- a/common/settings.py +++ b/common/settings.py @@ -249,7 +249,11 @@ def init_settings(): ES = get_base_config("es", {}) docStoreConn = rag.utils.es_conn.ESConnection() elif lower_case_doc_engine == "infinity": - INFINITY = get_base_config("infinity", {"uri": "infinity:23817"}) + INFINITY = get_base_config("infinity", { + "uri": "infinity:23817", + "postgres_port": 5432, + "db_name": "default_db" + }) docStoreConn = rag.utils.infinity_conn.InfinityConnection() elif lower_case_doc_engine == "opensearch": OS = get_base_config("os", {}) @@ -269,7 +273,11 @@ def init_settings(): ES = get_base_config("es", {}) msgStoreConn = memory_es_conn.ESConnection() elif DOC_ENGINE == "infinity": - INFINITY = get_base_config("infinity", {"uri": "infinity:23817"}) + INFINITY = get_base_config("infinity", { + "uri": "infinity:23817", + "postgres_port": 5432, + "db_name": "default_db" + }) msgStoreConn = memory_infinity_conn.InfinityConnection() global AZURE, S3, MINIO, OSS, GCS diff --git a/conf/service_conf.yaml b/conf/service_conf.yaml index afd9b98bc..04a316488 100644 --- a/conf/service_conf.yaml +++ b/conf/service_conf.yaml @@ -29,6 +29,7 @@ os: password: 'infini_rag_flow_OS_01' infinity: uri: 'localhost:23817' + postgres_port: 5432 db_name: 'default_db' oceanbase: scheme: 'oceanbase' # set 'mysql' to create connection using mysql config diff --git a/docker/service_conf.yaml.template b/docker/service_conf.yaml.template index 6e08e962a..c03eaf2a9 100644 --- a/docker/service_conf.yaml.template +++ b/docker/service_conf.yaml.template @@ -29,6 +29,7 @@ os: password: '${OPENSEARCH_PASSWORD:-infini_rag_flow_OS_01}' infinity: uri: '${INFINITY_HOST:-infinity}:23817' + postgres_port: 5432 db_name: 'default_db' oceanbase: scheme: 'oceanbase' # set 'mysql' to create connection using mysql config diff --git a/rag/app/table.py b/rag/app/table.py index f931d2849..1b49994e5 100644 --- a/rag/app/table.py +++ b/rag/app/table.py @@ -33,6 +33,7 @@ from deepdoc.parser.figure_parser import vision_figure_parser_figure_xlsx_wrappe from deepdoc.parser.utils import get_text from rag.nlp import rag_tokenizer, tokenize, tokenize_table from deepdoc.parser import 
ExcelParser +from common import settings class Excel(ExcelParser): @@ -431,7 +432,9 @@ def chunk(filename, binary=None, from_page=0, to_page=10000000000, lang="Chinese res = [] PY = Pinyin() - fieds_map = {"text": "_tks", "int": "_long", "keyword": "_kwd", "float": "_flt", "datetime": "_dt", "bool": "_kwd"} + # Field type suffixes for database columns + # Maps data types to their database field suffixes + fields_map = {"text": "_tks", "int": "_long", "keyword": "_kwd", "float": "_flt", "datetime": "_dt", "bool": "_kwd"} for df in dfs: for n in ["id", "_id", "index", "idx"]: if n in df.columns: @@ -452,13 +455,24 @@ def chunk(filename, binary=None, from_page=0, to_page=10000000000, lang="Chinese df[clmns[j]] = cln if ty == "text": txts.extend([str(c) for c in cln if c]) - clmns_map = [(py_clmns[i].lower() + fieds_map[clmn_tys[i]], str(clmns[i]).replace("_", " ")) for i in + clmns_map = [(py_clmns[i].lower() + fields_map[clmn_tys[i]], str(clmns[i]).replace("_", " ")) for i in range(len(clmns))] + # For Infinity: Use original column names as keys since they're stored in chunk_data JSON + # For ES/OS: Use full field names with type suffixes (e.g., url_kwd, body_tks) + if settings.DOC_ENGINE_INFINITY: + # For Infinity: key = original column name, value = display name + field_map = {py_clmns[i].lower(): str(clmns[i]).replace("_", " ") for i in range(len(clmns))} + else: + # For ES/OS: key = typed field name, value = display name + field_map = {k: v for k, v in clmns_map} + logging.debug(f"Field map: {field_map}") + KnowledgebaseService.update_parser_config(kwargs["kb_id"], {"field_map": field_map}) eng = lang.lower() == "english" # is_english(txts) for ii, row in df.iterrows(): d = {"docnm_kwd": filename, "title_tks": rag_tokenizer.tokenize(re.sub(r"\.[a-zA-Z]+$", "", filename))} - row_txt = [] + row_fields = [] + data_json = {} # For Infinity: Store all columns in a JSON object for j in range(len(clmns)): if row[clmns[j]] is None: continue @@ -466,17 +480,27 @@ def chunk(filename, binary=None, from_page=0, to_page=10000000000, lang="Chinese continue if not isinstance(row[clmns[j]], pd.Series) and pd.isna(row[clmns[j]]): continue - fld = clmns_map[j][0] - d[fld] = row[clmns[j]] if clmn_tys[j] != "text" else rag_tokenizer.tokenize(row[clmns[j]]) - row_txt.append("{}:{}".format(clmns[j], row[clmns[j]])) - if not row_txt: + # For Infinity: Store in chunk_data JSON column + # For Elasticsearch/OpenSearch: Store as individual fields with type suffixes + if settings.DOC_ENGINE_INFINITY: + data_json[str(clmns[j])] = row[clmns[j]] + else: + fld = clmns_map[j][0] + d[fld] = row[clmns[j]] if clmn_tys[j] != "text" else rag_tokenizer.tokenize(row[clmns[j]]) + row_fields.append((clmns[j], row[clmns[j]])) + if not row_fields: continue - tokenize(d, "; ".join(row_txt), eng) + # Add the data JSON field to the document (for Infinity only) + if settings.DOC_ENGINE_INFINITY: + d["chunk_data"] = data_json + # Format as a structured text for better LLM comprehension + # Format each field as "- Field Name: Value" on separate lines + formatted_text = "\n".join([f"- {field}: {value}" for field, value in row_fields]) + tokenize(d, formatted_text, eng) res.append(d) if tbls: doc = {"docnm_kwd": filename, "title_tks": rag_tokenizer.tokenize(re.sub(r"\.[a-zA-Z]+$", "", filename))} res.extend(tokenize_table(tbls, doc, is_english)) - KnowledgebaseService.update_parser_config(kwargs["kb_id"], {"field_map": {k: v for k, v in clmns_map}}) callback(0.35, "") return res diff --git a/rag/svr/task_executor.py 
b/rag/svr/task_executor.py index 622da3834..cf2a37bea 100644 --- a/rag/svr/task_executor.py +++ b/rag/svr/task_executor.py @@ -558,7 +558,8 @@ def build_TOC(task, docs, progress_callback): def init_kb(row, vector_size: int): idxnm = search.index_name(row["tenant_id"]) - return settings.docStoreConn.create_idx(idxnm, row.get("kb_id", ""), vector_size) + parser_id = row.get("parser_id", None) + return settings.docStoreConn.create_idx(idxnm, row.get("kb_id", ""), vector_size, parser_id) async def embedding(docs, mdl, parser_config=None, callback=None): @@ -739,7 +740,7 @@ async def run_dataflow(task: dict): start_ts = timer() set_progress(task_id, prog=0.82, msg="[DOC Engine]:\nStart to index...") - e = await insert_es(task_id, task["tenant_id"], task["kb_id"], chunks, partial(set_progress, task_id, 0, 100000000)) + e = await insert_chunks(task_id, task["tenant_id"], task["kb_id"], chunks, partial(set_progress, task_id, 0, 100000000)) if not e: PipelineOperationLogService.create(document_id=doc_id, pipeline_id=dataflow_id, task_type=PipelineTaskType.PARSE, dsl=str(pipeline)) @@ -833,7 +834,17 @@ async def delete_image(kb_id, chunk_id): raise -async def insert_es(task_id, task_tenant_id, task_dataset_id, chunks, progress_callback): +async def insert_chunks(task_id, task_tenant_id, task_dataset_id, chunks, progress_callback): + """ + Insert chunks into document store (Elasticsearch OR Infinity). + + Args: + task_id: Task identifier + task_tenant_id: Tenant ID + task_dataset_id: Dataset/knowledge base ID + chunks: List of chunk dictionaries to insert + progress_callback: Callback function for progress updates + """ mothers = [] mother_ids = set([]) for ck in chunks: @@ -858,7 +869,7 @@ async def insert_es(task_id, task_tenant_id, task_dataset_id, chunks, progress_c for b in range(0, len(mothers), settings.DOC_BULK_SIZE): await asyncio.to_thread(settings.docStoreConn.insert, mothers[b:b + settings.DOC_BULK_SIZE], - search.index_name(task_tenant_id), task_dataset_id, ) + search.index_name(task_tenant_id), task_dataset_id) task_canceled = has_canceled(task_id) if task_canceled: progress_callback(-1, msg="Task has been canceled.") @@ -866,7 +877,7 @@ async def insert_es(task_id, task_tenant_id, task_dataset_id, chunks, progress_c for b in range(0, len(chunks), settings.DOC_BULK_SIZE): doc_store_result = await asyncio.to_thread(settings.docStoreConn.insert, chunks[b:b + settings.DOC_BULK_SIZE], - search.index_name(task_tenant_id), task_dataset_id, ) + search.index_name(task_tenant_id), task_dataset_id) task_canceled = has_canceled(task_id) if task_canceled: progress_callback(-1, msg="Task has been canceled.") @@ -932,13 +943,6 @@ async def do_handle_task(task): # prepare the progress callback function progress_callback = partial(set_progress, task_id, task_from_page, task_to_page) - # FIXME: workaround, Infinity doesn't support table parsing method, this check is to notify user - lower_case_doc_engine = settings.DOC_ENGINE.lower() - if lower_case_doc_engine == 'infinity' and task['parser_id'].lower() == 'table': - error_message = "Table parsing method is not supported by Infinity, please use other parsing methods or use Elasticsearch as the document engine." 
- progress_callback(-1, msg=error_message) - raise Exception(error_message) - task_canceled = has_canceled(task_id) if task_canceled: progress_callback(-1, msg="Task has been canceled.") @@ -1092,14 +1096,14 @@ async def do_handle_task(task): chunk_count = len(set([chunk["id"] for chunk in chunks])) start_ts = timer() - async def _maybe_insert_es(_chunks): + async def _maybe_insert_chunks(_chunks): if has_canceled(task_id): return True - insert_result = await insert_es(task_id, task_tenant_id, task_dataset_id, _chunks, progress_callback) + insert_result = await insert_chunks(task_id, task_tenant_id, task_dataset_id, _chunks, progress_callback) return bool(insert_result) try: - if not await _maybe_insert_es(chunks): + if not await _maybe_insert_chunks(chunks): return logging.info( @@ -1115,7 +1119,7 @@ async def do_handle_task(task): if toc_thread: d = toc_thread.result() if d: - if not await _maybe_insert_es([d]): + if not await _maybe_insert_chunks([d]): return DocumentService.increment_chunk_num(task_doc_id, task_dataset_id, 0, 1, 0) diff --git a/rag/utils/infinity_conn.py b/rag/utils/infinity_conn.py index f65ae3eaf..916f919ee 100644 --- a/rag/utils/infinity_conn.py +++ b/rag/utils/infinity_conn.py @@ -317,7 +317,18 @@ class InfinityConnection(InfinityConnectionBase): break if vector_size == 0: raise ValueError("Cannot infer vector size from documents") - self.create_idx(index_name, knowledgebase_id, vector_size) + + # Determine parser_id from document structure + # Table parser documents have 'chunk_data' field + parser_id = None + if "chunk_data" in documents[0] and isinstance(documents[0].get("chunk_data"), dict): + from common.constants import ParserType + parser_id = ParserType.TABLE.value + self.logger.debug("Detected TABLE parser from document structure") + + # Fallback: Create table with base schema (shouldn't normally happen as init_kb() creates it) + self.logger.debug(f"Fallback: Creating table {table_name} with base schema, parser_id: {parser_id}") + self.create_idx(index_name, knowledgebase_id, vector_size, parser_id) table_instance = db_instance.get_table(table_name) # embedding fields can't have a default value.... @@ -378,6 +389,12 @@ class InfinityConnection(InfinityConnectionBase): d[k] = v elif re.search(r"_feas$", k): d[k] = json.dumps(v) + elif k == "chunk_data": + # Convert data dict to JSON string for storage + if isinstance(v, dict): + d[k] = json.dumps(v) + else: + d[k] = v elif k == "kb_id": if isinstance(d[k], list): d[k] = d[k][0] # since d[k] is a list, but we need a str @@ -586,6 +603,9 @@ class InfinityConnection(InfinityConnectionBase): res2[column] = res2[column].apply(lambda v: [kwd for kwd in v.split("###") if kwd]) elif re.search(r"_feas$", k): res2[column] = res2[column].apply(lambda v: json.loads(v) if v else {}) + elif k == "chunk_data": + # Parse JSON data back to dict for table parser fields + res2[column] = res2[column].apply(lambda v: json.loads(v) if v and isinstance(v, str) else v) elif k == "position_int": def to_position_int(v): if v: diff --git a/test/testcases/test_http_api/common.py b/test/testcases/test_http_api/common.py index 6810ca647..7e1d9927a 100644 --- a/test/testcases/test_http_api/common.py +++ b/test/testcases/test_http_api/common.py @@ -49,6 +49,11 @@ def update_dataset(auth, dataset_id, payload=None, *, headers=HEADERS, data=None def delete_datasets(auth, payload=None, *, headers=HEADERS, data=None): + """ + Delete datasets. 
+ The endpoint is DELETE /api/{VERSION}/datasets with payload {"ids": [...]} + This is the standard SDK REST API endpoint for dataset deletion. + """ res = requests.delete(url=f"{HOST_ADDRESS}{DATASETS_API_URL}", headers=headers, auth=auth, json=payload, data=data) return res.json() @@ -300,12 +305,6 @@ def metadata_summary(auth, dataset_id, params=None): # CHAT COMPLETIONS AND RELATED QUESTIONS -def chat_completions(auth, chat_assistant_id, payload=None): - url = f"{HOST_ADDRESS}{CHAT_ASSISTANT_API_URL}/{chat_assistant_id}/completions" - res = requests.post(url=url, headers=HEADERS, auth=auth, json=payload) - return res.json() - - def related_questions(auth, payload=None): url = f"{HOST_ADDRESS}/api/{VERSION}/sessions/related_questions" res = requests.post(url=url, headers=HEADERS, auth=auth, json=payload) @@ -355,3 +354,23 @@ def agent_completions(auth, agent_id, payload=None): url = f"{HOST_ADDRESS}{AGENT_API_URL}/{agent_id}/completions" res = requests.post(url=url, headers=HEADERS, auth=auth, json=payload) return res.json() + + +def chat_completions(auth, chat_id, payload=None): + """ + Send a question/message to a chat assistant and get completion. + + Args: + auth: Authentication object + chat_id: Chat assistant ID + payload: Dictionary containing: + - question: str (required) - The question to ask + - stream: bool (optional) - Whether to stream responses, default False + - session_id: str (optional) - Session ID for conversation context + + Returns: + Response JSON with answer data + """ + url = f"{HOST_ADDRESS}/api/{VERSION}/chats/{chat_id}/completions" + res = requests.post(url=url, headers=HEADERS, auth=auth, json=payload) + return res.json() diff --git a/test/testcases/test_http_api/test_chat_management/conftest.py b/test/testcases/test_http_api/test_chat_management/conftest.py new file mode 100644 index 000000000..cf64a5889 --- /dev/null +++ b/test/testcases/test_http_api/test_chat_management/conftest.py @@ -0,0 +1,42 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + + +import pytest +from common import create_dataset, delete_datasets + + +@pytest.fixture(scope="class") +def add_table_parser_dataset(HttpApiAuth, request): + """ + Fixture to create a table parser dataset for testing. + Automatically cleans up after tests complete (deletes dataset and table). + Note: field_map is automatically generated by the table parser when processing files. 
+ """ + dataset_payload = { + "name": "test_table_parser_dataset", + "chunk_method": "table", # table parser + } + res = create_dataset(HttpApiAuth, dataset_payload) + assert res["code"] == 0, f"Failed to create dataset: {res}" + dataset_id = res["data"]["id"] + + def cleanup(): + delete_datasets(HttpApiAuth, {"ids": [dataset_id]}) + + request.addfinalizer(cleanup) + + return dataset_id diff --git a/test/testcases/test_http_api/test_chat_management/test_table_parser_dataset_chat.py b/test/testcases/test_http_api/test_chat_management/test_table_parser_dataset_chat.py new file mode 100644 index 000000000..b34a34f62 --- /dev/null +++ b/test/testcases/test_http_api/test_chat_management/test_table_parser_dataset_chat.py @@ -0,0 +1,324 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import os +import re +import tempfile + +import pytest + +from common import ( + chat_completions, + create_chat_assistant, + create_session_with_chat_assistant, + delete_chat_assistants, + list_documents, + upload_documents, + parse_documents, +) +from utils import wait_for + +@wait_for(200, 1, "Document parsing timeout") +def wait_for_parsing_completion(auth, dataset_id, document_id=None): + """ + Wait for document parsing to complete. 
+ + Args: + auth: Authentication object + dataset_id: Dataset ID + document_id: Optional specific document ID to wait for + + Returns: + bool: True if parsing is complete, False otherwise + """ + res = list_documents(auth, dataset_id) + docs = res["data"]["docs"] + + if document_id is None: + # Wait for all documents to complete + for doc in docs: + status = doc.get("run", "UNKNOWN") + if status != "DONE": + print(f"[DEBUG] Document {doc.get('name', 'unknown')} status: {status}, progress: {doc.get('progress', 0)}%, msg: {doc.get('progress_msg', '')}") + return False + return True + else: + # Wait for specific document + for doc in docs: + if doc["id"] == document_id: + status = doc.get("run", "UNKNOWN") + print(f"[DEBUG] Document {doc.get('name', 'unknown')} status: {status}, progress: {doc.get('progress', 0)}%, msg: {doc.get('progress_msg', '')}") + if status == "DONE": + return True + elif status == "FAILED": + pytest.fail(f"Document parsing failed: {doc}") + return False + return False + +# Test data +TEST_EXCEL_DATA = [ + ["employee_id", "name", "department", "salary"], + ["E001", "Alice Johnson", "Engineering", "95000"], + ["E002", "Bob Smith", "Marketing", "65000"], + ["E003", "Carol Williams", "Engineering", "88000"], + ["E004", "David Brown", "Sales", "72000"], + ["E005", "Eva Davis", "HR", "68000"], + ["E006", "Frank Miller", "Engineering", "102000"], +] + +TEST_EXCEL_DATA_2 = [ + ["product", "price", "category"], + ["Laptop", "999", "Electronics"], + ["Mouse", "29", "Electronics"], + ["Desk", "299", "Furniture"], + ["Chair", "199", "Furniture"], + ["Monitor", "399", "Electronics"], + ["Keyboard", "79", "Electronics"], +] + +DEFAULT_CHAT_PROMPT = ( + "You are a helpful assistant that answers questions about table data using SQL queries.\n\n" + "Here is the knowledge base:\n{knowledge}\n\n" + "Use this information to answer questions." +) + + +@pytest.mark.usefixtures("add_table_parser_dataset") +class TestTableParserDatasetChat: + """ + Test table parser dataset chat functionality with Infinity backend. + + Verifies that: + 1. Excel files are uploaded and parsed correctly into table parser datasets + 2. Chat assistants can query the parsed table data via SQL + 3. Different types of queries work + """ + + @pytest.fixture(autouse=True) + def setup_chat_assistant(self, HttpApiAuth, add_table_parser_dataset, request): + """ + Setup fixture that runs before each test method. + Creates chat assistant once and reuses it across all test cases. 
+ """ + # Only setup once (first time) + if not hasattr(self.__class__, 'chat_id'): + self.__class__.dataset_id = add_table_parser_dataset + self.__class__.auth = HttpApiAuth + + # Upload and parse Excel files once for all tests + self._upload_and_parse_excel(HttpApiAuth, add_table_parser_dataset) + + # Create a single chat assistant and session for all tests + chat_id, session_id = self._create_chat_assistant_with_session( + HttpApiAuth, add_table_parser_dataset + ) + self.__class__.chat_id = chat_id + self.__class__.session_id = session_id + + # Store the total number of parametrize cases + mark = request.node.get_closest_marker('parametrize') + if mark: + # Get the number of test cases from parametrize + param_values = mark.args[1] + self.__class__._total_tests = len(param_values) + else: + self.__class__._total_tests = 1 + + yield + + # Teardown: cleanup chat assistant after all tests + # Use a class-level counter to track tests + if not hasattr(self.__class__, '_test_counter'): + self.__class__._test_counter = 0 + self.__class__._test_counter += 1 + + # Cleanup after all parametrize tests complete + if self.__class__._test_counter >= self.__class__._total_tests: + self._teardown_chat_assistant() + + def _teardown_chat_assistant(self): + """Teardown method to clean up chat assistant.""" + if hasattr(self.__class__, 'chat_id') and self.__class__.chat_id: + try: + delete_chat_assistants(self.__class__.auth, {"ids": [self.__class__.chat_id]}) + except Exception as e: + print(f"[Teardown] Warning: Failed to delete chat assistant: {e}") + + @pytest.mark.p1 + @pytest.mark.parametrize( + "question, expected_answer_pattern", + [ + ("show me column of product", r"\|product\|Source"), + ("which product has price 79", r"Keyboard"), + ("How many rows in the dataset?", r"count\(\*\)"), + ("Show me all employees in Engineering department", r"(Alice|Carol|Frank)"), + ], + ) + def test_table_parser_dataset_chat(self, question, expected_answer_pattern): + """ + Test that table parser dataset chat works correctly. + """ + # Use class-level attributes (set by setup fixture) + answer = self._ask_question( + self.__class__.auth, + self.__class__.chat_id, + self.__class__.session_id, + question + ) + + # Verify answer matches expected pattern if provided + if expected_answer_pattern: + self._assert_answer_matches_pattern(answer, expected_answer_pattern) + else: + # Just verify we got a non-empty answer + assert answer and len(answer) > 0, "Expected non-empty answer" + + print(f"[Test] Question: {question}") + print(f"[Test] Answer: {answer[:100]}...") + + @staticmethod + def _upload_and_parse_excel(auth, dataset_id): + """ + Upload 2 Excel files and wait for parsing to complete. 
+ + Returns: + list: The document IDs of the uploaded files + + Raises: + AssertionError: If upload or parsing fails + """ + excel_file_paths = [] + document_ids = [] + try: + # Create 2 temporary Excel files + excel_file_paths.append(TestTableParserDatasetChat._create_temp_excel_file(TEST_EXCEL_DATA)) + excel_file_paths.append(TestTableParserDatasetChat._create_temp_excel_file(TEST_EXCEL_DATA_2)) + + # Upload documents + res = upload_documents(auth, dataset_id, excel_file_paths) + assert res["code"] == 0, f"Failed to upload documents: {res}" + + for doc in res["data"]: + document_ids.append(doc["id"]) + + # Start parsing for all documents + parse_payload = {"document_ids": document_ids} + res = parse_documents(auth, dataset_id, parse_payload) + assert res["code"] == 0, f"Failed to start parsing: {res}" + + # Wait for parsing completion for all documents + for doc_id in document_ids: + wait_for_parsing_completion(auth, dataset_id, doc_id) + + return document_ids + + finally: + # Clean up temporary files + for excel_file_path in excel_file_paths: + if excel_file_path: + os.unlink(excel_file_path) + + @staticmethod + def _create_temp_excel_file(data): + """ + Create a temporary Excel file with the given table test data. + + Args: + data: List of lists containing the Excel data + + Returns: + str: Path to the created temporary file + """ + from openpyxl import Workbook + + f = tempfile.NamedTemporaryFile(mode="wb", suffix=".xlsx", delete=False) + f.close() + + wb = Workbook() + ws = wb.active + + # Write test data to the worksheet + for row_idx, row_data in enumerate(data, start=1): + for col_idx, value in enumerate(row_data, start=1): + ws.cell(row=row_idx, column=col_idx, value=value) + + wb.save(f.name) + return f.name + + @staticmethod + def _create_chat_assistant_with_session(auth, dataset_id): + """ + Create a chat assistant and session for testing. + + Returns: + tuple: (chat_id, session_id) + """ + import uuid + + chat_payload = { + "name": f"test_table_parser_dataset_chat_{uuid.uuid4().hex[:8]}", + "dataset_ids": [dataset_id], + "prompt_config": { + "system": DEFAULT_CHAT_PROMPT, + "parameters": [ + { + "key": "knowledge", + "optional": True, + "value": "Use the table data to answer questions with SQL queries.", + } + ], + }, + } + + res = create_chat_assistant(auth, chat_payload) + assert res["code"] == 0, f"Failed to create chat assistant: {res}" + chat_id = res["data"]["id"] + + res = create_session_with_chat_assistant(auth, chat_id, {"name": f"test_session_{uuid.uuid4().hex[:8]}"}) + assert res["code"] == 0, f"Failed to create session: {res}" + session_id = res["data"]["id"] + + return chat_id, session_id + + def _ask_question(self, auth, chat_id, session_id, question): + """ + Send a question to the chat assistant and return the answer. + + Returns: + str: The assistant's answer + """ + payload = { + "question": question, + "stream": False, + "session_id": session_id, + } + + res_json = chat_completions(auth, chat_id, payload) + assert res_json["code"] == 0, f"Chat completion failed: {res_json}" + + return res_json["data"]["answer"] + + def _assert_answer_matches_pattern(self, answer, pattern): + """ + Assert that the answer matches the expected pattern. + + Args: + answer: The actual answer from the chat assistant + pattern: Regular expression pattern to match + """ + assert re.search(pattern, answer, re.IGNORECASE), ( + f"Answer does not match expected pattern '{pattern}'.\n" + f"Answer: {answer}" + )
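
---

Reviewer notes (illustrative, not part of the diff). The core change above is that text-to-SQL now targets two storage layouts: Infinity keeps one table per knowledge base and stores table-parser rows in a JSON `chunk_data` column, while Elasticsearch/OpenSearch keeps one index per tenant with typed field names and a `kb_id` filter. The Python sketch below only illustrates the SQL shapes the new prompts steer the LLM toward; the tenant/KB IDs and the `product`/`price` fields are hypothetical (borrowed from the test fixture), and the real statements are produced by the LLM at runtime.

```python
# Illustrative sketch only; not part of the patch.
tenant_id = "tenant123"  # hypothetical
kb_id = "kb456"          # hypothetical

# Infinity: table ragflow_{tenant_id}_{kb_id}; row values are read out of the JSON
# 'chunk_data' column with json_extract_string(), plus CAST() for numeric comparisons.
infinity_sql = (
    f"SELECT doc_id, docnm, "
    f"json_extract_string(chunk_data, '$.product') AS product "
    f"FROM ragflow_{tenant_id}_{kb_id} "
    f"WHERE CAST(json_extract_string(chunk_data, '$.price') AS FLOAT) > 100"
)

# Elasticsearch/OpenSearch: index ragflow_{tenant_id}; typed field names are queried
# directly (assuming 'product' is text -> product_tks and 'price' is float -> price_flt)
# and the kb_id restriction is appended to the WHERE clause by use_sql().
es_sql = (
    f"SELECT doc_id, docnm_kwd, product_tks "
    f"FROM ragflow_{tenant_id} "
    f"WHERE kb_id = '{kb_id}' AND price_flt > 100"
)

print(infinity_sql)
print(es_sql)
```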
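
A second sketch, also illustrative rather than authoritative: the generated statement reaches Infinity over its PostgreSQL wire protocol. `InfinityConnectionBase.sql()` shells out to the `psql` client installed by the Dockerfile change and parses the tabular stdout. The host, port, and statement below are placeholders; the real values come from `INFINITY_CONN.get_conn_uri()` and the LLM.

```python
import subprocess

host, port = "infinity", "5432"  # placeholders; the patch derives these from get_conn_uri()
sql = "SELECT doc_id, docnm FROM ragflow_tenant123_kb456 LIMIT 5"  # hypothetical table

result = subprocess.run(
    ["psql", "-h", host, "-p", port, "-c", sql],
    capture_output=True,
    text=True,
    timeout=10,
)
# stdout is psql's "col | col" table: the header row is split on "|" for column names,
# the dashed separator and the "(N rows)" footer are skipped, and the remaining rows are
# returned as {"columns": [...], "rows": [...]} to settings.retriever.sql_retrieval().
print(result.stdout)
```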