mirror of
https://github.com/infiniflow/ragflow.git
synced 2025-12-08 20:42:30 +08:00
Feat: add meta data filter. (#9405)
### What problem does this PR solve? #8531 #7417 #6761 #6573 #6477 ### Type of change - [x] New Feature (non-breaking change which adds functionality)
This commit is contained in:
@ -51,6 +51,7 @@ def set_dialog():
|
|||||||
similarity_threshold = req.get("similarity_threshold", 0.1)
|
similarity_threshold = req.get("similarity_threshold", 0.1)
|
||||||
vector_similarity_weight = req.get("vector_similarity_weight", 0.3)
|
vector_similarity_weight = req.get("vector_similarity_weight", 0.3)
|
||||||
llm_setting = req.get("llm_setting", {})
|
llm_setting = req.get("llm_setting", {})
|
||||||
|
meta_data_filter = req.get("meta_data_filter", {})
|
||||||
prompt_config = req["prompt_config"]
|
prompt_config = req["prompt_config"]
|
||||||
|
|
||||||
if not is_create:
|
if not is_create:
|
||||||
@ -85,6 +86,7 @@ def set_dialog():
|
|||||||
"llm_id": llm_id,
|
"llm_id": llm_id,
|
||||||
"llm_setting": llm_setting,
|
"llm_setting": llm_setting,
|
||||||
"prompt_config": prompt_config,
|
"prompt_config": prompt_config,
|
||||||
|
"meta_data_filter": meta_data_filter,
|
||||||
"top_n": top_n,
|
"top_n": top_n,
|
||||||
"top_k": top_k,
|
"top_k": top_k,
|
||||||
"rerank_id": rerank_id,
|
"rerank_id": rerank_id,
|
||||||
|
|||||||
@ -681,6 +681,11 @@ def set_meta():
|
|||||||
return get_json_result(data=False, message="No authorization.", code=settings.RetCode.AUTHENTICATION_ERROR)
|
return get_json_result(data=False, message="No authorization.", code=settings.RetCode.AUTHENTICATION_ERROR)
|
||||||
try:
|
try:
|
||||||
meta = json.loads(req["meta"])
|
meta = json.loads(req["meta"])
|
||||||
|
if not isinstance(meta, dict):
|
||||||
|
return get_json_result(data=False, message="Only dictionary type supported.", code=settings.RetCode.ARGUMENT_ERROR)
|
||||||
|
for k,v in meta.items():
|
||||||
|
if not isinstance(v, str) and not isinstance(v, int) and not isinstance(v, float):
|
||||||
|
return get_json_result(data=False, message=f"The type is not supported: {v}", code=settings.RetCode.ARGUMENT_ERROR)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return get_json_result(data=False, message=f"Json syntax error: {e}", code=settings.RetCode.ARGUMENT_ERROR)
|
return get_json_result(data=False, message=f"Json syntax error: {e}", code=settings.RetCode.ARGUMENT_ERROR)
|
||||||
if not isinstance(meta, dict):
|
if not isinstance(meta, dict):
|
||||||
|
|||||||
@ -351,6 +351,7 @@ def knowledge_graph(kb_id):
|
|||||||
obj["graph"]["edges"] = sorted(filtered_edges, key=lambda x: x.get("weight", 0), reverse=True)[:128]
|
obj["graph"]["edges"] = sorted(filtered_edges, key=lambda x: x.get("weight", 0), reverse=True)[:128]
|
||||||
return get_json_result(data=obj)
|
return get_json_result(data=obj)
|
||||||
|
|
||||||
|
|
||||||
@manager.route('/<kb_id>/knowledge_graph', methods=['DELETE']) # noqa: F821
|
@manager.route('/<kb_id>/knowledge_graph', methods=['DELETE']) # noqa: F821
|
||||||
@login_required
|
@login_required
|
||||||
def delete_knowledge_graph(kb_id):
|
def delete_knowledge_graph(kb_id):
|
||||||
@ -364,3 +365,17 @@ def delete_knowledge_graph(kb_id):
|
|||||||
settings.docStoreConn.delete({"knowledge_graph_kwd": ["graph", "subgraph", "entity", "relation"]}, search.index_name(kb.tenant_id), kb_id)
|
settings.docStoreConn.delete({"knowledge_graph_kwd": ["graph", "subgraph", "entity", "relation"]}, search.index_name(kb.tenant_id), kb_id)
|
||||||
|
|
||||||
return get_json_result(data=True)
|
return get_json_result(data=True)
|
||||||
|
|
||||||
|
|
||||||
|
@manager.route("/get_meta", methods=["GET"])  # noqa: F821
@login_required
def get_meta():
    """Return the aggregated document metadata of the requested knowledge bases.

    The ``kb_ids`` query-string argument is split on commas (a missing
    argument is treated as the empty string before splitting). The first id
    that ``KnowledgebaseService.accessible`` rejects for the current user
    short-circuits the request with an authentication-error result;
    otherwise the result of ``DocumentService.get_meta_by_kbs`` for all
    requested ids is returned.
    """
    requested_ids = request.args.get("kb_ids", "").split(",")

    # any() stops at the first inaccessible knowledge base, exactly like an
    # early return from an explicit loop would.
    access_denied = any(
        not KnowledgebaseService.accessible(one_id, current_user.id)
        for one_id in requested_ids
    )
    if access_denied:
        return get_json_result(
            data=False,
            message='No authorization.',
            code=settings.RetCode.AUTHENTICATION_ERROR
        )

    return get_json_result(data=DocumentService.get_meta_by_kbs(requested_ids))
|
||||||
|
|||||||
@ -744,6 +744,7 @@ class Dialog(DataBaseModel):
|
|||||||
null=False,
|
null=False,
|
||||||
default={"system": "", "prologue": "Hi! I'm your assistant, what can I do for you?", "parameters": [], "empty_response": "Sorry! No relevant content was found in the knowledge base!"},
|
default={"system": "", "prologue": "Hi! I'm your assistant, what can I do for you?", "parameters": [], "empty_response": "Sorry! No relevant content was found in the knowledge base!"},
|
||||||
)
|
)
|
||||||
|
meta_data_filter = JSONField(null=True, default={})
|
||||||
|
|
||||||
similarity_threshold = FloatField(default=0.2)
|
similarity_threshold = FloatField(default=0.2)
|
||||||
vector_similarity_weight = FloatField(default=0.3)
|
vector_similarity_weight = FloatField(default=0.3)
|
||||||
@ -1015,4 +1016,8 @@ def migrate_db():
|
|||||||
migrate(migrator.add_column("api_4_conversation", "errors", TextField(null=True, help_text="errors")))
|
migrate(migrator.add_column("api_4_conversation", "errors", TextField(null=True, help_text="errors")))
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
try:
|
||||||
|
migrate(migrator.add_column("dialog", "meta_data_filter", JSONField(null=True, default={})))
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
logging.disable(logging.NOTSET)
|
logging.disable(logging.NOTSET)
|
||||||
@ -30,6 +30,7 @@ from api import settings
|
|||||||
from api.db import LLMType, ParserType, StatusEnum
|
from api.db import LLMType, ParserType, StatusEnum
|
||||||
from api.db.db_models import DB, Dialog
|
from api.db.db_models import DB, Dialog
|
||||||
from api.db.services.common_service import CommonService
|
from api.db.services.common_service import CommonService
|
||||||
|
from api.db.services.document_service import DocumentService
|
||||||
from api.db.services.knowledgebase_service import KnowledgebaseService
|
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||||
from api.db.services.langfuse_service import TenantLangfuseService
|
from api.db.services.langfuse_service import TenantLangfuseService
|
||||||
from api.db.services.llm_service import LLMBundle, TenantLLMService
|
from api.db.services.llm_service import LLMBundle, TenantLLMService
|
||||||
@ -38,6 +39,7 @@ from rag.app.resume import forbidden_select_fields4resume
|
|||||||
from rag.app.tag import label_question
|
from rag.app.tag import label_question
|
||||||
from rag.nlp.search import index_name
|
from rag.nlp.search import index_name
|
||||||
from rag.prompts import chunks_format, citation_prompt, cross_languages, full_question, kb_prompt, keyword_extraction, message_fit_in
|
from rag.prompts import chunks_format, citation_prompt, cross_languages, full_question, kb_prompt, keyword_extraction, message_fit_in
|
||||||
|
from rag.prompts.prompts import gen_meta_filter
|
||||||
from rag.utils import num_tokens_from_string, rmSpace
|
from rag.utils import num_tokens_from_string, rmSpace
|
||||||
from rag.utils.tavily_conn import Tavily
|
from rag.utils.tavily_conn import Tavily
|
||||||
|
|
||||||
@ -250,6 +252,46 @@ def repair_bad_citation_formats(answer: str, kbinfos: dict, idx: set):
|
|||||||
return answer, idx
|
return answer, idx
|
||||||
|
|
||||||
|
|
||||||
|
def meta_filter(metas: dict, filters: list[dict], logic: str = "and"):
    """Select document ids whose metadata satisfies the given filters.

    Args:
        metas: Mapping ``{attribute: {value: [doc_id, ...]}}``.
        filters: Filter objects with the keys ``"key"`` (attribute name),
            ``"op"`` (operator) and ``"value"`` (operand). Supported
            operators: ``contains``, ``not contains``, ``start with``,
            ``end with``, ``empty``, ``not empty``, ``=``, ``≠``, ``>``,
            ``<``, ``≥``, ``≤``.
        logic: ``"and"`` (default) keeps only documents matched by every
            applicable filter; ``"or"`` keeps documents matched by any of
            them. Ranges are expressed as two conditions (``≥ start`` and
            ``< end``), which only narrows the result when they are
            intersected.

    Returns:
        A list of matching doc ids without duplicates, in first-match order.
        Filters that are not dicts, lack ``"key"``/``"op"``, or name an
        attribute absent from ``metas`` are ignored. An unknown operator
        matches nothing.
    """

    def as_number(text):
        # Return the float value of ``text`` or None when it is not numeric.
        try:
            return float(text)
        except (TypeError, ValueError):
            return None

    ordering = {
        "=": lambda a, b: a == b,
        "≠": lambda a, b: a != b,
        ">": lambda a, b: a > b,
        "<": lambda a, b: a < b,
        "≥": lambda a, b: a >= b,
        "≤": lambda a, b: a <= b,
    }

    def matches(operator, raw_input, raw_value):
        # Text operators always work on the raw strings, case-insensitively,
        # so a numeric-looking operand such as "5" is never turned into "5.0".
        lowered_input, lowered_value = raw_input.lower(), raw_value.lower()
        if operator == "contains":
            return lowered_value in lowered_input
        if operator == "not contains":
            return lowered_value not in lowered_input
        if operator == "start with":
            return lowered_input.startswith(lowered_value)
        if operator == "end with":
            return lowered_input.endswith(lowered_value)
        # Emptiness is about the stored text, not its numeric value ("0" is
        # not empty).
        if operator == "empty":
            return not raw_input
        if operator == "not empty":
            return bool(raw_input)

        compare = ordering.get(operator)
        if compare is None:
            return False
        # Compare numerically only when both sides are numbers; otherwise
        # fall back to plain string comparison (e.g. ISO dates).
        number_input, number_value = as_number(raw_input), as_number(raw_value)
        if number_input is not None and number_value is not None:
            return compare(number_input, number_value)
        return compare(raw_input, raw_value)

    selected = None
    for f in filters:
        if not isinstance(f, dict):
            continue
        key, operator = f.get("key"), f.get("op")
        if operator is None or not isinstance(key, str) or key not in metas:
            continue
        value = f.get("value")
        value = "" if value is None else str(value)

        hits = list(dict.fromkeys(
            doc_id
            for raw_input, docids in metas[key].items()
            if matches(operator, str(raw_input), value)
            for doc_id in docids
        ))

        if selected is None:
            selected = hits
        elif logic == "or":
            selected = list(dict.fromkeys(selected + hits))
        else:
            keep = set(hits)
            selected = [doc_id for doc_id in selected if doc_id in keep]

    return selected or []
|
||||||
|
|
||||||
|
|
||||||
def chat(dialog, messages, stream=True, **kwargs):
|
def chat(dialog, messages, stream=True, **kwargs):
|
||||||
assert messages[-1]["role"] == "user", "The last content of this conversation is not from user."
|
assert messages[-1]["role"] == "user", "The last content of this conversation is not from user."
|
||||||
if not dialog.kb_ids and not dialog.prompt_config.get("tavily_api_key"):
|
if not dialog.kb_ids and not dialog.prompt_config.get("tavily_api_key"):
|
||||||
@ -287,9 +329,10 @@ def chat(dialog, messages, stream=True, **kwargs):
|
|||||||
|
|
||||||
retriever = settings.retrievaler
|
retriever = settings.retrievaler
|
||||||
questions = [m["content"] for m in messages if m["role"] == "user"][-3:]
|
questions = [m["content"] for m in messages if m["role"] == "user"][-3:]
|
||||||
attachments = kwargs["doc_ids"].split(",") if "doc_ids" in kwargs else None
|
attachments = kwargs["doc_ids"].split(",") if "doc_ids" in kwargs else []
|
||||||
if "doc_ids" in messages[-1]:
|
if "doc_ids" in messages[-1]:
|
||||||
attachments = messages[-1]["doc_ids"]
|
attachments = messages[-1]["doc_ids"]
|
||||||
|
|
||||||
prompt_config = dialog.prompt_config
|
prompt_config = dialog.prompt_config
|
||||||
field_map = KnowledgebaseService.get_field_map(dialog.kb_ids)
|
field_map = KnowledgebaseService.get_field_map(dialog.kb_ids)
|
||||||
# try to use sql if field mapping is good to go
|
# try to use sql if field mapping is good to go
|
||||||
@ -316,6 +359,14 @@ def chat(dialog, messages, stream=True, **kwargs):
|
|||||||
if prompt_config.get("cross_languages"):
|
if prompt_config.get("cross_languages"):
|
||||||
questions = [cross_languages(dialog.tenant_id, dialog.llm_id, questions[0], prompt_config["cross_languages"])]
|
questions = [cross_languages(dialog.tenant_id, dialog.llm_id, questions[0], prompt_config["cross_languages"])]
|
||||||
|
|
||||||
|
if dialog.meta_data_filter:
|
||||||
|
metas = DocumentService.get_meta_by_kbs(dialog.kb_ids)
|
||||||
|
if dialog.meta_data_filter.get("method") == "auto":
|
||||||
|
filters = gen_meta_filter(chat_mdl, metas, questions[-1])
|
||||||
|
attachments.extend(meta_filter(metas, filters))
|
||||||
|
elif dialog.meta_data_filter.get("method") == "manual":
|
||||||
|
attachments.extend(meta_filter(metas, dialog.meta_data_filter["manual"]))
|
||||||
|
|
||||||
if prompt_config.get("keyword", False):
|
if prompt_config.get("keyword", False):
|
||||||
questions[-1] += keyword_extraction(chat_mdl, questions[-1])
|
questions[-1] += keyword_extraction(chat_mdl, questions[-1])
|
||||||
|
|
||||||
|
|||||||
@ -574,6 +574,25 @@ class DocumentService(CommonService):
|
|||||||
def update_meta_fields(cls, doc_id, meta_fields):
|
def update_meta_fields(cls, doc_id, meta_fields):
|
||||||
return cls.update_by_id(doc_id, {"meta_fields": meta_fields})
|
return cls.update_by_id(doc_id, {"meta_fields": meta_fields})
|
||||||
|
|
||||||
|
@classmethod
@DB.connection_context()
def get_meta_by_kbs(cls, kb_ids):
    """Aggregate the ``meta_fields`` of all documents in the given knowledge bases.

    Args:
        kb_ids: Knowledge-base ids used in the ``kb_id IN (...)`` condition.

    Returns:
        A dict ``{meta_key: {str(meta_value): [doc_id, ...]}}``. Metadata
        values are stringified before being used as keys.
    """
    fields = [
        cls.model.id,
        cls.model.meta_fields,
    ]
    meta = {}
    for r in cls.model.select(*fields).where(cls.model.kb_id.in_(kb_ids)):
        # Rows without usable metadata (e.g. None) have nothing to
        # contribute and must not abort the whole aggregation.
        if not isinstance(r.meta_fields, dict):
            continue
        for k, v in r.meta_fields.items():
            meta.setdefault(k, {}).setdefault(str(v), []).append(r.id)
    return meta
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@DB.connection_context()
|
@DB.connection_context()
|
||||||
def update_progress(cls):
|
def update_progress(cls):
|
||||||
|
|||||||
@ -383,8 +383,6 @@ class Dealer:
|
|||||||
vector_column = f"q_{dim}_vec"
|
vector_column = f"q_{dim}_vec"
|
||||||
zero_vector = [0.0] * dim
|
zero_vector = [0.0] * dim
|
||||||
sim_np = np.array(sim)
|
sim_np = np.array(sim)
|
||||||
if doc_ids:
|
|
||||||
similarity_threshold = 0
|
|
||||||
filtered_count = (sim_np >= similarity_threshold).sum()
|
filtered_count = (sim_np >= similarity_threshold).sum()
|
||||||
ranks["total"] = int(filtered_count) # Convert from np.int64 to Python int otherwise JSON serializable error
|
ranks["total"] = int(filtered_count) # Convert from np.int64 to Python int otherwise JSON serializable error
|
||||||
for i in idx:
|
for i in idx:
|
||||||
|
|||||||
53
rag/prompts/meta_filter.md
Normal file
53
rag/prompts/meta_filter.md
Normal file
@ -0,0 +1,53 @@
|
|||||||
|
You are a metadata filtering condition generator. Analyze the user's question and available document metadata to output a JSON array of filter objects. Follow these rules:
|
||||||
|
|
||||||
|
1. **Metadata Structure**:
|
||||||
|
- Metadata is provided as JSON where keys are attribute names (e.g., "color"), and values are objects mapping attribute values to document IDs.
|
||||||
|
- Example:
|
||||||
|
{
|
||||||
|
"color": {"red": ["doc1"], "blue": ["doc2"]},
|
||||||
|
"listing_date": {"2025-07-11": ["doc1"], "2025-08-01": ["doc2"]}
|
||||||
|
}
|
||||||
|
|
||||||
|
2. **Output Requirements**:
|
||||||
|
- Always output a JSON array of filter objects
|
||||||
|
- Each object must have:
|
||||||
|
"key": (metadata attribute name),
|
||||||
|
"value": (string value to compare),
|
||||||
|
"op": (operator from allowed list)
|
||||||
|
|
||||||
|
3. **Operator Guide**:
|
||||||
|
- Use these operators only: ["contains", "not contains", "start with", "end with", "empty", "not empty", "=", "≠", ">", "<", "≥", "≤"]
|
||||||
|
- Date ranges: Break into two conditions (≥ start_date AND < next_month_start)
|
||||||
|
- Negations: Always use "≠" for exclusion terms ("not", "except", "exclude", "≠")
|
||||||
|
- Implicit logic: Derive unstated filters (e.g., "July" → [≥ YYYY-07-01, < YYYY-08-01])
|
||||||
|
|
||||||
|
4. **Processing Steps**:
|
||||||
|
a) Identify ALL filterable attributes in the query (both explicit and implicit)
|
||||||
|
b) For dates:
|
||||||
|
- Infer missing year from current date if needed
|
||||||
|
- Always format dates as "YYYY-MM-DD"
|
||||||
|
- Convert ranges: [≥ start, < end]
|
||||||
|
c) For values: Match EXACTLY to metadata's value keys
|
||||||
|
d) Skip conditions if:
|
||||||
|
- Attribute doesn't exist in metadata
|
||||||
|
- Value has no match in metadata
|
||||||
|
|
||||||
|
5. **Example**:
|
||||||
|
- User query: "上市日期七月份的有哪些商品,不要蓝色的"
|
||||||
|
- Metadata: { "color": {...}, "listing_date": {...} }
|
||||||
|
- Output:
|
||||||
|
[
|
||||||
|
{"key": "listing_date", "value": "2025-07-01", "op": "≥"},
|
||||||
|
{"key": "listing_date", "value": "2025-08-01", "op": "<"},
|
||||||
|
{"key": "color", "value": "blue", "op": "≠"}
|
||||||
|
]
|
||||||
|
|
||||||
|
6. **Final Output**:
|
||||||
|
- ONLY output valid JSON array
|
||||||
|
- NO additional text/explanations
|
||||||
|
|
||||||
|
**Current Task**:
|
||||||
|
- Today's date: {{current_date}}
|
||||||
|
- Available metadata keys: {{metadata_keys}}
|
||||||
|
- User query: "{{user_question}}"
|
||||||
|
|
||||||
@ -149,6 +149,7 @@ NEXT_STEP = load_prompt("next_step")
|
|||||||
REFLECT = load_prompt("reflect")
|
REFLECT = load_prompt("reflect")
|
||||||
SUMMARY4MEMORY = load_prompt("summary4memory")
|
SUMMARY4MEMORY = load_prompt("summary4memory")
|
||||||
RANK_MEMORY = load_prompt("rank_memory")
|
RANK_MEMORY = load_prompt("rank_memory")
|
||||||
|
META_FILTER = load_prompt("meta_filter")
|
||||||
|
|
||||||
PROMPT_JINJA_ENV = jinja2.Environment(autoescape=False, trim_blocks=True, lstrip_blocks=True)
|
PROMPT_JINJA_ENV = jinja2.Environment(autoescape=False, trim_blocks=True, lstrip_blocks=True)
|
||||||
|
|
||||||
@ -413,3 +414,20 @@ def rank_memories(chat_mdl, goal:str, sub_goal:str, tool_call_summaries: list[st
|
|||||||
ans = chat_mdl.chat(msg[0]["content"], msg[1:], stop="<|stop|>")
|
ans = chat_mdl.chat(msg[0]["content"], msg[1:], stop="<|stop|>")
|
||||||
return re.sub(r"^.*</think>", "", ans, flags=re.DOTALL)
|
return re.sub(r"^.*</think>", "", ans, flags=re.DOTALL)
|
||||||
|
|
||||||
|
|
||||||
|
def gen_meta_filter(chat_mdl, meta_data: dict, query: str) -> list:
    """Ask the chat model to turn a user question into metadata filters.

    The ``META_FILTER`` prompt is rendered with today's date, the JSON dump
    of ``meta_data`` and the user question, then sent to ``chat_mdl``. Any
    leading ``...</think>`` section and Markdown JSON code fences are
    stripped from the answer before it is parsed.

    Args:
        chat_mdl: Chat model object exposing ``chat(system, history)``.
        meta_data: Metadata mapping passed to the prompt as JSON.
        query: The user question.

    Returns:
        The parsed list of filter objects, or ``[]`` when the answer cannot
        be parsed or is not a JSON array.
    """
    sys_prompt = PROMPT_JINJA_ENV.from_string(META_FILTER).render(
        current_date=datetime.datetime.today().strftime('%Y-%m-%d'),
        metadata_keys=json.dumps(meta_data),
        user_question=query
    )
    user_prompt = "Generate filters:"
    ans = chat_mdl.chat(sys_prompt, [{"role": "user", "content": user_prompt}])
    ans = re.sub(r"(^.*</think>|```json\n|```\n*$)", "", ans, flags=re.DOTALL)
    try:
        filters = json_repair.loads(ans)
    except Exception:
        logging.exception(f"Loading json failure: {ans}")
        return []
    # Explicit check instead of ``assert``: assertions are removed when
    # Python runs with -O, which would let a non-list answer reach callers.
    if not isinstance(filters, list):
        logging.error(f"Loading json failure: {filters}")
        return []
    return filters
|
||||||
@ -444,7 +444,7 @@ async def embedding(docs, mdl, parser_config=None, callback=None):
|
|||||||
tts = np.concatenate([vts for _ in range(len(tts))], axis=0)
|
tts = np.concatenate([vts for _ in range(len(tts))], axis=0)
|
||||||
tk_count += c
|
tk_count += c
|
||||||
|
|
||||||
@timeout(5)
|
@timeout(60)
|
||||||
def batch_encode(txts):
|
def batch_encode(txts):
|
||||||
nonlocal mdl
|
nonlocal mdl
|
||||||
return mdl.encode([truncate(c, mdl.max_length-10) for c in txts])
|
return mdl.encode([truncate(c, mdl.max_length-10) for c in txts])
|
||||||
|
|||||||
@ -190,3 +190,17 @@ class RAGFlowS3:
|
|||||||
self.__open__()
|
self.__open__()
|
||||||
time.sleep(1)
|
time.sleep(1)
|
||||||
return
|
return
|
||||||
|
|
||||||
|
@use_prefix_path
@use_default_bucket
def rm_bucket(self, bucket, *args, **kwargs):
    """Delete ``bucket`` together with the objects listed in it.

    Iterates over ``self.conn``; for the first connection on which the
    bucket exists, every listed object is deleted, then the bucket itself,
    and the method returns. Any exception raised for a connection is logged
    and the next connection is tried.
    """
    # NOTE(review): ``bucket_exists`` / ``o.object_name`` look like a
    # MinIO-style client while ``list_objects_v2(Bucket=...)`` and
    # ``delete_bucket(Bucket=...)`` look like boto3 - confirm which client
    # type ``self.conn`` actually holds and that it is iterable.
    for client in self.conn:
        try:
            if conn_has_bucket := client.bucket_exists(bucket):
                for entry in client.list_objects_v2(Bucket=bucket):
                    client.delete_object(bucket, entry.object_name)
                client.delete_bucket(Bucket=bucket)
                return
        except Exception as e:
            logging.error(f"Fail rm {bucket}: " + str(e))
|
||||||
|
|||||||
Reference in New Issue
Block a user