Fix: support cross language for API. (#8946)

### What problem does this PR solve?

Close #8943

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
This commit is contained in:
Kevin Hu
2025-07-21 17:25:28 +08:00
committed by GitHub
parent 7eb5ea3814
commit 0b487dee43
7 changed files with 17 additions and 4 deletions

View File

@ -38,7 +38,7 @@ from api.utils.api_utils import check_duplicate_ids, construct_json_result, get_
from rag.app.qa import beAdoc, rmPrefix from rag.app.qa import beAdoc, rmPrefix
from rag.app.tag import label_question from rag.app.tag import label_question
from rag.nlp import rag_tokenizer, search from rag.nlp import rag_tokenizer, search
from rag.prompts import keyword_extraction from rag.prompts import keyword_extraction, cross_languages
from rag.utils import rmSpace from rag.utils import rmSpace
from rag.utils.storage_factory import STORAGE_IMPL from rag.utils.storage_factory import STORAGE_IMPL
@ -1382,6 +1382,7 @@ def retrieval_test(tenant_id):
question = req["question"] question = req["question"]
doc_ids = req.get("document_ids", []) doc_ids = req.get("document_ids", [])
use_kg = req.get("use_kg", False) use_kg = req.get("use_kg", False)
langs = req.get("cross_languages", [])
if not isinstance(doc_ids, list): if not isinstance(doc_ids, list):
return get_error_data_result("`documents` should be a list") return get_error_data_result("`documents` should be a list")
doc_ids_list = KnowledgebaseService.list_documents_by_ids(kb_ids) doc_ids_list = KnowledgebaseService.list_documents_by_ids(kb_ids)
@ -1406,6 +1407,9 @@ def retrieval_test(tenant_id):
if req.get("rerank_id"): if req.get("rerank_id"):
rerank_mdl = LLMBundle(kb.tenant_id, LLMType.RERANK, llm_name=req["rerank_id"]) rerank_mdl = LLMBundle(kb.tenant_id, LLMType.RERANK, llm_name=req["rerank_id"])
if langs:
question = cross_languages(kb.tenant_id, None, question, langs)
if req.get("keyword", False): if req.get("keyword", False):
chat_mdl = LLMBundle(kb.tenant_id, LLMType.CHAT) chat_mdl = LLMBundle(kb.tenant_id, LLMType.CHAT)
question += keyword_extraction(chat_mdl, question) question += keyword_extraction(chat_mdl, question)

View File

@ -1725,6 +1725,7 @@ Retrieves chunks from specified datasets.
- `"rerank_id"`: `string` - `"rerank_id"`: `string`
- `"keyword"`: `boolean` - `"keyword"`: `boolean`
- `"highlight"`: `boolean` - `"highlight"`: `boolean`
- `"cross_languages"`: `list[string]`
##### Request example ##### Request example
@ -1769,6 +1770,8 @@ curl --request POST \
Specifies whether to enable highlighting of matched terms in the results: Specifies whether to enable highlighting of matched terms in the results:
- `true`: Enable highlighting of matched terms. - `true`: Enable highlighting of matched terms.
- `false`: Disable highlighting of matched terms (default). - `false`: Disable highlighting of matched terms (default).
- `"cross_languages"`: (*Body parameter*) `list[string]`
The target languages into which the question should be translated, enabling keyword retrieval across multiple languages.
#### Response #### Response

View File

@ -953,6 +953,10 @@ Specifies whether to enable highlighting of matched terms in the results:
- `True`: Enable highlighting of matched terms. - `True`: Enable highlighting of matched terms.
- `False`: Disable highlighting of matched terms (default). - `False`: Disable highlighting of matched terms (default).
##### cross_languages: `list[string]`
The target languages into which the question should be translated, enabling keyword retrieval across multiple languages.
#### Returns #### Returns
- Success: A list of `Chunk` objects representing the document chunks. - Success: A list of `Chunk` objects representing the document chunks.

View File

@ -250,5 +250,5 @@ class Extractor:
use_prompt = prompt_template.format(**context_base) use_prompt = prompt_template.format(**context_base)
logging.info(f"Trigger summary: {entity_or_relation_name}") logging.info(f"Trigger summary: {entity_or_relation_name}")
async with chat_limiter: async with chat_limiter:
summary = await trio.to_thread.run_sync(lambda: self._chat(use_prompt, [{"role": "user", "content": "Output: "}], {"temperature": 0.8})) summary = await trio.to_thread.run_sync(lambda: self._chat(use_prompt, [{"role": "user", "content": "Output: "}]))
return summary return summary

View File

@ -128,7 +128,7 @@ class GraphExtractor(Extractor):
history.append({"role": "assistant", "content": response}) history.append({"role": "assistant", "content": response})
history.append({"role": "user", "content": LOOP_PROMPT}) history.append({"role": "user", "content": LOOP_PROMPT})
async with chat_limiter: async with chat_limiter:
continuation = await trio.to_thread.run_sync(lambda: self._chat("", history, {"temperature": 0.8})) continuation = await trio.to_thread.run_sync(lambda: self._chat("", history))
token_count += num_tokens_from_string("\n".join([m["content"] for m in history]) + response) token_count += num_tokens_from_string("\n".join([m["content"] for m in history]) + response)
if continuation != "Y": if continuation != "Y":
break break

View File

@ -86,7 +86,7 @@ class GraphExtractor(Extractor):
**self._context_base, input_text="{input_text}" **self._context_base, input_text="{input_text}"
).format(**self._context_base, input_text=content) ).format(**self._context_base, input_text=content)
gen_conf = {"temperature": 0.8} gen_conf = {}
async with chat_limiter: async with chat_limiter:
final_result = await trio.to_thread.run_sync(lambda: self._chat(hint_prompt, [{"role": "user", "content": "Output:"}], gen_conf)) final_result = await trio.to_thread.run_sync(lambda: self._chat(hint_prompt, [{"role": "user", "content": "Output:"}], gen_conf))
token_count += num_tokens_from_string(hint_prompt + final_result) token_count += num_tokens_from_string(hint_prompt + final_result)

View File

@ -197,6 +197,7 @@ class RAGFlow:
top_k=1024, top_k=1024,
rerank_id: str | None = None, rerank_id: str | None = None,
keyword: bool = False, keyword: bool = False,
cross_languages: list[str]|None = None
): ):
if document_ids is None: if document_ids is None:
document_ids = [] document_ids = []
@ -211,6 +212,7 @@ class RAGFlow:
"question": question, "question": question,
"dataset_ids": dataset_ids, "dataset_ids": dataset_ids,
"document_ids": document_ids, "document_ids": document_ids,
"cross_languages": cross_languages
} }
# Send a POST request to the backend service (using requests library as an example, actual implementation may vary) # Send a POST request to the backend service (using requests library as an example, actual implementation may vary)
res = self.post("/retrieval", json=data_json) res = self.post("/retrieval", json=data_json)