Mirror of https://github.com/infiniflow/ragflow.git (synced 2026-02-07 02:55:08 +08:00)

Compare commits: 0b7b88592f ... 2cb1046cbf (5 commits)

Commits:
- 2cb1046cbf
- a880beb1f6
- 34283d4db4
- 5629fbd2ca
- b7aa6d6c4f
@@ -632,7 +632,9 @@ class AdminCLI(Cmd):
         response = self.session.get(url)
         res_json = response.json()
         if response.status_code == 200:
-            self._print_table_simple(res_json['data'])
+            table_data = res_json['data']
+            table_data.pop('avatar')
+            self._print_table_simple(table_data)
         else:
             print(f"Fail to get user {user_name}, code: {res_json['code']}, message: {res_json['message']}")
@@ -705,7 +707,10 @@ class AdminCLI(Cmd):
         response = self.session.get(url)
         res_json = response.json()
         if response.status_code == 200:
-            self._print_table_simple(res_json['data'])
+            table_data = res_json['data']
+            for t in table_data:
+                t.pop('avatar')
+            self._print_table_simple(table_data)
         else:
             print(f"Fail to get all datasets of {user_name}, code: {res_json['code']}, message: {res_json['message']}")
@@ -717,7 +722,10 @@ class AdminCLI(Cmd):
         response = self.session.get(url)
         res_json = response.json()
        if response.status_code == 200:
-            self._print_table_simple(res_json['data'])
+            table_data = res_json['data']
+            for t in table_data:
+                t.pop('avatar')
+            self._print_table_simple(table_data)
         else:
             print(f"Fail to get all agents of {user_name}, code: {res_json['code']}, message: {res_json['message']}")
@@ -52,6 +52,7 @@ class UserMgr:
         result = []
         for user in users:
             result.append({
+                'avatar': user.avatar,
                 'email': user.email,
                 'language': user.language,
                 'last_login_time': user.last_login_time,
@@ -170,7 +171,8 @@ class UserServiceMgr:
         return [{
             'title': r['title'],
             'permission': r['permission'],
-            'canvas_category': r['canvas_category'].split('_')[0]
+            'canvas_category': r['canvas_category'].split('_')[0],
+            'avatar': r['avatar']
         } for r in res]
@@ -19,7 +19,7 @@ from flask import request
 from flask_login import login_required, current_user

 from api.db import InputType
-from api.db.services.connector_service import ConnectorService, Connector2KbService, SyncLogsService
+from api.db.services.connector_service import ConnectorService, SyncLogsService
 from api.utils.api_utils import get_json_result, validate_request, get_data_error_result
 from common.misc_utils import get_uuid
 from common.constants import RetCode, TaskStatus
@@ -88,14 +88,14 @@ def resume(connector_id):
     return get_json_result(data=True)


-@manager.route("/<connector_id>/link", methods=["POST"]) # noqa: F821
-@validate_request("kb_ids")
+@manager.route("/<connector_id>/rebuild", methods=["PUT"]) # noqa: F821
 @login_required
-def link_kb(connector_id):
+@validate_request("kb_id")
+def rebuild(connector_id):
     req = request.json
-    errors = Connector2KbService.link_kb(connector_id, req["kb_ids"], current_user.id)
-    if errors:
-        return get_json_result(data=False, message=errors, code=RetCode.SERVER_ERROR)
+    err = ConnectorService.rebuild(connector_id, req["kb_id"], current_user.id)
+    if err:
+        return get_json_result(data=False, message=err, code=RetCode.SERVER_ERROR)
     return get_json_result(data=True)
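For orientation, the hunk above swaps the old `link` route for a `rebuild` one: a client now issues a PUT with a single `kb_id` in the body instead of POSTing a `kb_ids` list. The snippet below is only a minimal sketch of how such a call might look; the host, port, `/v1/connector` prefix, and IDs are assumptions for illustration, and the session is assumed to already carry a valid login cookie (the route is `@login_required`).

```python
# Hedged sketch of calling the new rebuild endpoint (host, prefix, and IDs are assumed).
import requests

session = requests.Session()
# ... assume this session has already authenticated against the RAGFlow server ...

connector_id = "my-connector-id"  # hypothetical connector ID
resp = session.put(
    f"http://127.0.0.1:9380/v1/connector/{connector_id}/rebuild",  # assumed base URL and prefix
    json={"kb_id": "my-kb-id"},  # validate_request("kb_id") expects this field
)
print(resp.json())  # mirrors get_json_result(data=True) on success, an error message otherwise
```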
@@ -260,6 +260,8 @@ def list_docs():
         for doc_item in docs:
             if doc_item["thumbnail"] and not doc_item["thumbnail"].startswith(IMG_BASE64_PREFIX):
                 doc_item["thumbnail"] = f"/v1/document/image/{kb_id}-{doc_item['thumbnail']}"
+            if doc_item.get("source_type"):
+                doc_item["source_type"] = doc_item["source_type"].split("/")[0]

         return get_json_result(data={"total": tol, "docs": docs})
     except Exception as e:
@@ -1064,6 +1064,7 @@ class Connector2Kb(DataBaseModel):
     id = CharField(max_length=32, primary_key=True)
     connector_id = CharField(max_length=32, null=False, index=True)
     kb_id = CharField(max_length=32, null=False, index=True)
+    auto_parse = CharField(max_length=1, null=False, default="1", index=False)

     class Meta:
         db_table = "connector2kb"
@@ -1282,4 +1283,8 @@ def migrate_db():
         migrate(migrator.add_column("tenant_llm", "status", CharField(max_length=1, null=False, help_text="is it validate(0: wasted, 1: validate)", default="1", index=True)))
     except Exception:
         pass
+    try:
+        migrate(migrator.add_column("connector2kb", "auto_parse", CharField(max_length=1, null=False, default="1", index=False)))
+    except Exception:
+        pass
     logging.disable(logging.NOTSET)
@@ -67,6 +67,7 @@ class UserCanvasService(CommonService):
         # will get all permitted agents, be cautious
         fields = [
             cls.model.id,
+            cls.model.avatar,
             cls.model.title,
             cls.model.permission,
             cls.model.canvas_type,
@@ -54,7 +54,6 @@ class ConnectorService(CommonService):
             SyncLogsService.update_by_id(task["id"], task)
         ConnectorService.update_by_id(connector_id, {"status": status})
-

     @classmethod
     def list(cls, tenant_id):
         fields = [
@@ -67,6 +66,15 @@ class ConnectorService(CommonService):
             cls.model.tenant_id == tenant_id
         ).dicts())

+    @classmethod
+    def rebuild(cls, kb_id: str, connector_id: str, tenant_id: str):
+        e, conn = cls.get_by_id(connector_id)
+        if not e:
+            return
+        SyncLogsService.filter_delete([SyncLogs.connector_id==connector_id, SyncLogs.kb_id==kb_id])
+        docs = DocumentService.query(source_type=f"{conn.source}/{conn.id}")
+        return FileService.delete_docs([d.id for d in docs], tenant_id)
+

 class SyncLogsService(CommonService):
     model = SyncLogs
@@ -91,6 +99,7 @@ class SyncLogsService(CommonService):
             Connector.timeout_secs,
             Knowledgebase.name.alias("kb_name"),
             Knowledgebase.avatar.alias("kb_avatar"),
+            Connector2Kb.auto_parse,
             cls.model.from_beginning.alias("reindex"),
             cls.model.status
         ]
@@ -179,7 +188,7 @@ class SyncLogsService(CommonService):
             .where(cls.model.id == id).execute()

     @classmethod
-    def duplicate_and_parse(cls, kb, docs, tenant_id, src):
+    def duplicate_and_parse(cls, kb, docs, tenant_id, src, auto_parse=True):
         if not docs:
             return None

@@ -191,14 +200,17 @@
                 return self.blob

         errs = []
-        files = [FileObj(filename=d["semantic_identifier"]+f".{d['extension']}", blob=d["blob"]) for d in docs]
+        files = [FileObj(filename=d["semantic_identifier"]+(f"{d['extension']}" if d["semantic_identifier"][::-1].find(d['extension'][::-1])<0 else ""), blob=d["blob"]) for d in docs]
         doc_ids = []
         err, doc_blob_pairs = FileService.upload_document(kb, files, tenant_id, src)
         errs.extend(err)

         kb_table_num_map = {}
         for doc, _ in doc_blob_pairs:
-            DocumentService.run(tenant_id, doc, kb_table_num_map)
             doc_ids.append(doc["id"])
+            if not auto_parse or auto_parse == "0":
+                continue
+            DocumentService.run(tenant_id, doc, kb_table_num_map)

         return errs, doc_ids
@@ -213,33 +225,6 @@
 class Connector2KbService(CommonService):
     model = Connector2Kb

-    @classmethod
-    def link_kb(cls, conn_id:str, kb_ids: list[str], tenant_id:str):
-        arr = cls.query(connector_id=conn_id)
-        old_kb_ids = [a.kb_id for a in arr]
-        for kb_id in kb_ids:
-            if kb_id in old_kb_ids:
-                continue
-            cls.save(**{
-                "id": get_uuid(),
-                "connector_id": conn_id,
-                "kb_id": kb_id
-            })
-            SyncLogsService.schedule(conn_id, kb_id, reindex=True)
-
-        errs = []
-        e, conn = ConnectorService.get_by_id(conn_id)
-        for kb_id in old_kb_ids:
-            if kb_id in kb_ids:
-                continue
-            cls.filter_delete([cls.model.kb_id==kb_id, cls.model.connector_id==conn_id])
-            SyncLogsService.filter_update([SyncLogs.connector_id==conn_id, SyncLogs.kb_id==kb_id, SyncLogs.status==TaskStatus.SCHEDULE], {"status": TaskStatus.CANCEL})
-            docs = DocumentService.query(source_type=f"{conn.source}/{conn.id}")
-            err = FileService.delete_docs([d.id for d in docs], tenant_id)
-            if err:
-                errs.append(err)
-        return "\n".join(errs)
-
     @classmethod
     def link_connectors(cls, kb_id:str, connector_ids: list[str], tenant_id:str):
         arr = cls.query(kb_id=kb_id)
@@ -260,11 +245,15 @@ class Connector2KbService(CommonService):
                 continue
             cls.filter_delete([cls.model.kb_id==kb_id, cls.model.connector_id==conn_id])
             e, conn = ConnectorService.get_by_id(conn_id)
-            SyncLogsService.filter_update([SyncLogs.connector_id==conn_id, SyncLogs.kb_id==kb_id, SyncLogs.status==TaskStatus.SCHEDULE], {"status": TaskStatus.CANCEL})
-            docs = DocumentService.query(source_type=f"{conn.source}/{conn.id}")
-            err = FileService.delete_docs([d.id for d in docs], tenant_id)
-            if err:
-                errs.append(err)
+            if not e:
+                continue
+            #SyncLogsService.filter_delete([SyncLogs.connector_id==conn_id, SyncLogs.kb_id==kb_id])
+            # Do not delete docs while unlinking.
+            SyncLogsService.filter_update([SyncLogs.connector_id==conn_id, SyncLogs.kb_id==kb_id, SyncLogs.status.in_([TaskStatus.SCHEDULE, TaskStatus.RUNNING])], {"status": TaskStatus.CANCEL})
+            #docs = DocumentService.query(source_type=f"{conn.source}/{conn.id}")
+            #err = FileService.delete_docs([d.id for d in docs], tenant_id)
+            #if err:
+            #    errs.append(err)
         return "\n".join(errs)

     @classmethod
@@ -282,3 +271,5 @@ class Connector2KbService(CommonService):
             ).dicts()
         )

@@ -201,6 +201,7 @@ class KnowledgebaseService(CommonService):
         # will get all permitted kb, be cautious.
         fields = [
             cls.model.name,
+            cls.model.avatar,
             cls.model.language,
             cls.model.permission,
             cls.model.doc_num,
@@ -159,7 +159,7 @@ class PipelineOperationLogService(CommonService):
             document_name=document.name,
             document_suffix=document.suffix,
             document_type=document.type,
-            source_from="", # TODO: add in the future
+            source_from=document.source_type.split("/")[0],
             progress=document.progress,
             progress_msg=document.progress_msg,
             process_begin_at=document.process_begin_at,
@@ -146,7 +146,10 @@ def get_redis_info():
 def check_ragflow_server_alive():
     start_time = timer()
     try:
-        response = requests.get(f'http://{settings.HOST_IP}:{settings.HOST_PORT}/v1/system/ping')
+        url = f'http://{settings.HOST_IP}:{settings.HOST_PORT}/v1/system/ping'
+        if '0.0.0.0' in url:
+            url.replace('0.0.0.0', '127.0.0.1')
+        response = requests.get(url)
         if response.status_code == 200:
             return {"status": "alive", "message": f"Confirm elapsed: {(timer() - start_time) * 1000.0:.1f} ms."}
         else:
@@ -253,6 +253,8 @@ class NotionConnector(LoadConnector, PollConnector):
         all_child_page_ids: list[str] = []

         for page in pages:
+            if isinstance(page, dict):
+                page = NotionPage(**page)
             if page.id in self.indexed_pages:
                 logging.debug(f"Already indexed page with ID '{page.id}'. Skipping.")
                 continue
@@ -1840,7 +1840,7 @@ Retrieves chunks from specified datasets.
 - `"highlight"`: `boolean`
 - `"cross_languages"`: `list[string]`
 - `"metadata_condition"`: `object`
+- `"use_kg"`: `boolean`

 ##### Request example

 ```bash
@@ -1888,6 +1888,8 @@ curl --request POST \
   The weight of vector cosine similarity. Defaults to `0.3`. If x represents the weight of vector cosine similarity, then (1 - x) is the term similarity weight.
 - `"top_k"`: (*Body parameter*), `integer`
   The number of chunks engaged in vector cosine computation. Defaults to `1024`.
+- `"use_kg"`: (*Body parameter*), `boolean`
+  The search includes text chunks related to the knowledge graph of the selected dataset to handle complex multi-hop queries. Defaults to `False`.
 - `"rerank_id"`: (*Body parameter*), `integer`
   The ID of the rerank model.
 - `"keyword"`: (*Body parameter*), `boolean`
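For context, the two documentation hunks above add the `use_kg` body parameter to the chunk-retrieval API reference. Below is a minimal sketch of a request that sets it; the server address, endpoint path, API key, and dataset ID are placeholders and assumptions, since only the parameter names and defaults come from the hunks themselves.

```python
# Hedged sketch of a retrieval request enabling the newly documented use_kg flag.
# Host, path, API key, and dataset ID are placeholders, not values from this diff.
import requests

resp = requests.post(
    "http://127.0.0.1:9380/api/v1/retrieval",  # assumed server address and path
    headers={"Authorization": "Bearer <YOUR_API_KEY>"},
    json={
        "question": "What does the warranty cover?",
        "dataset_ids": ["<dataset_id>"],
        "vector_similarity_weight": 0.3,  # weight of vector cosine similarity (default 0.3)
        "top_k": 1024,                    # chunks engaged in vector cosine computation
        "use_kg": True,                   # also pull knowledge-graph-related chunks
    },
)
print(resp.json())
```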
@@ -724,8 +724,15 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
     elif re.search(r"\.doc$", filename, re.IGNORECASE):
         callback(0.1, "Start to parse.")

+        try:
+            from tika import parser as tika_parser
+        except Exception as e:
+            callback(0.8, f"tika not available: {e}. Unsupported .doc parsing.")
+            logging.warning(f"tika not available: {e}. Unsupported .doc parsing for {filename}.")
+            return []
+
         binary = BytesIO(binary)
-        doc_parsed = parser.from_buffer(binary)
+        doc_parsed = tika_parser.from_buffer(binary)
         if doc_parsed.get('content', None) is not None:
             sections = doc_parsed['content'].split('\n')
             sections = [(_, "") for _ in sections if _]
@@ -384,7 +384,7 @@ class Dealer:
                 rank_feature=rank_feature)
         else:
             lower_case_doc_engine = os.getenv('DOC_ENGINE', 'elasticsearch')
-            if lower_case_doc_engine == "elasticsearch":
+            if lower_case_doc_engine in ["elasticsearch","opensearch"]:
                 # ElasticSearch doesn't normalize each way score before fusion.
                 sim, tsim, vsim = self.rerank(
                     sres, question, 1 - vector_similarity_weight, vector_similarity_weight,
@@ -78,7 +78,7 @@ class SyncBase:
             } for doc in document_batch]

             e, kb = KnowledgebaseService.get_by_id(task["kb_id"])
-            err, dids = SyncLogsService.duplicate_and_parse(kb, docs, task["tenant_id"], f"{self.SOURCE_NAME}/{task['connector_id']}")
+            err, dids = SyncLogsService.duplicate_and_parse(kb, docs, task["tenant_id"], f"{self.SOURCE_NAME}/{task['connector_id']}", task["auto_parse"])
             SyncLogsService.increase_docs(task["id"], min_update, max_update, len(docs), "\n".join(err), len(err))
             doc_num += len(docs)
@@ -7,7 +7,7 @@ import classNames from 'classnames';
 import { get } from 'lodash';
 import { memo } from 'react';
 import { NodeHandleId } from '../../constant';
-import { useGetVariableLabelByValue } from '../../hooks/use-get-begin-query';
+import { useGetVariableLabelOrTypeByValue } from '../../hooks/use-get-begin-query';
 import { CommonHandle, LeftEndHandle } from './handle';
 import styles from './index.less';
 import NodeHeader from './node-header';
|||||||
const knowledgeBaseIds: string[] = get(data, 'form.kb_ids', []);
|
const knowledgeBaseIds: string[] = get(data, 'form.kb_ids', []);
|
||||||
const { list: knowledgeList } = useFetchKnowledgeList(true);
|
const { list: knowledgeList } = useFetchKnowledgeList(true);
|
||||||
|
|
||||||
const getLabel = useGetVariableLabelByValue(id);
|
const { getLabel } = useGetVariableLabelOrTypeByValue(id);
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<ToolBar selected={selected} id={id} label={data.label}>
|
<ToolBar selected={selected} id={id} label={data.label}>
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ import { LogicalOperatorIcon } from '@/hooks/logic-hooks/use-build-operator-opti
 import { ISwitchCondition, ISwitchNode } from '@/interfaces/database/flow';
 import { NodeProps, Position } from '@xyflow/react';
 import { memo, useCallback } from 'react';
-import { useGetVariableLabelByValue } from '../../hooks/use-get-begin-query';
+import { useGetVariableLabelOrTypeByValue } from '../../hooks/use-get-begin-query';
 import { CommonHandle, LeftEndHandle } from './handle';
 import { RightHandleStyle } from './handle-icon';
 import NodeHeader from './node-header';
@@ -27,7 +27,7 @@ const ConditionBlock = ({
   nodeId,
 }: { condition: ISwitchCondition } & { nodeId: string }) => {
   const items = condition?.items ?? [];
-  const getLabel = useGetVariableLabelByValue(nodeId);
+  const { getLabel } = useGetVariableLabelOrTypeByValue(nodeId);

   const renderOperatorIcon = useCallback((operator?: string) => {
     const item = SwitchOperatorOptions.find((x) => x.value === operator);
@@ -38,6 +38,7 @@ import TokenizerForm from '../form/tokenizer-form';
 import ToolForm from '../form/tool-form';
 import TuShareForm from '../form/tushare-form';
 import UserFillUpForm from '../form/user-fill-up-form';
+import VariableAggregatorForm from '../form/variable-aggregator-form';
 import VariableAssignerForm from '../form/variable-assigner-form';
 import WenCaiForm from '../form/wencai-form';
 import WikipediaForm from '../form/wikipedia-form';
@@ -186,4 +187,8 @@ export const FormConfigMap = {
   [Operator.VariableAssigner]: {
     component: VariableAssignerForm,
   },
+
+  [Operator.VariableAggregator]: {
+    component: VariableAggregatorForm,
+  },
 };
@@ -26,7 +26,6 @@ import { useTranslation } from 'react-i18next';
 import { z } from 'zod';
 import {
   AgentExceptionMethod,
-  JsonSchemaDataType,
   NodeHandleId,
   VariableType,
   initialAgentValues,
@@ -158,7 +157,6 @@ function AgentForm({ node }: INextOperatorForm) {
                   placeholder={t('flow.messagePlaceholder')}
                   showToolbar={true}
                   extraOptions={extraOptions}
-                  types={[JsonSchemaDataType.String]}
                 ></PromptEditor>
               </FormControl>
             </FormItem>
@@ -176,7 +174,6 @@ function AgentForm({ node }: INextOperatorForm) {
                   <PromptEditor
                     {...field}
                     showToolbar={true}
-                    types={[JsonSchemaDataType.String]}
                   ></PromptEditor>
                 </section>
               </FormControl>
@@ -53,10 +53,13 @@ export function StructuredOutputSecondaryMenu({

   const renderAgentStructuredOutput = useCallback(
     (values: any, option: { label: ReactNode; value: string }) => {
-      if (isPlainObject(values) && 'properties' in values) {
+      const properties =
+        get(values, 'properties') || get(values, 'items.properties');
+
+      if (isPlainObject(values) && properties) {
         return (
           <ul className="border-l">
-            {Object.entries(values.properties).map(([key, value]) => {
+            {Object.entries(properties).map(([key, value]) => {
               const nextOption = {
                 label: option.label + `.${key}`,
                 value: option.value + `.${key}`,
@@ -79,8 +82,9 @@ export function StructuredOutputSecondaryMenu({
                     {key}
                     <span className="text-text-secondary">{dataType}</span>
                   </div>
-                  {dataType === JsonSchemaDataType.Object &&
-                    renderAgentStructuredOutput(value, nextOption)}
+                  {[JsonSchemaDataType.Object, JsonSchemaDataType.Array].some(
+                    (x) => x === dataType,
+                  ) && renderAgentStructuredOutput(value, nextOption)}
                 </li>
               );
             }
@@ -32,7 +32,7 @@ import {
 } from '@/components/ui/tooltip';
 import { cn } from '@/lib/utils';
 import { useTranslation } from 'react-i18next';
-import { useGetVariableLabelByValue } from '../../hooks/use-get-begin-query';
+import { useGetVariableLabelOrTypeByValue } from '../../hooks/use-get-begin-query';
 import { VariableFormSchemaType } from './schema';

 interface IProps {
@@ -49,7 +49,7 @@ export function VariableTable({
   nodeId,
 }: IProps) {
   const { t } = useTranslation();
-  const getLabel = useGetVariableLabelByValue(nodeId!);
+  const { getLabel } = useGetVariableLabelOrTypeByValue(nodeId!);

   const [sorting, setSorting] = React.useState<SortingState>([]);
   const [columnFilters, setColumnFilters] = React.useState<ColumnFiltersState>(
@@ -14,7 +14,6 @@ import { memo } from 'react';
 import { useFieldArray, useForm } from 'react-hook-form';
 import { useTranslation } from 'react-i18next';
 import { z } from 'zod';
-import { JsonSchemaDataType } from '../../constant';
 import { INextOperatorForm } from '../../interface';
 import { FormWrapper } from '../components/form-wrapper';
 import { PromptEditor } from '../components/prompt-editor';
@@ -63,11 +62,9 @@ function MessageForm({ node }: INextOperatorForm) {
           render={({ field }) => (
             <FormItem className="flex-1">
               <FormControl>
-                {/* <Textarea {...field}> </Textarea> */}
                 <PromptEditor
                   {...field}
                   placeholder={t('flow.messagePlaceholder')}
-                  types={[JsonSchemaDataType.String]}
                 ></PromptEditor>
               </FormControl>
             </FormItem>
@@ -0,0 +1,88 @@
+import { RAGFlowFormItem } from '@/components/ragflow-form';
+import { Button } from '@/components/ui/button';
+import { Input } from '@/components/ui/input';
+import { Plus, Trash2 } from 'lucide-react';
+import { useFieldArray, useFormContext } from 'react-hook-form';
+import { useGetVariableLabelOrTypeByValue } from '../../hooks/use-get-begin-query';
+import { QueryVariable } from '../components/query-variable';
+
+type DynamicGroupVariableProps = {
+  name: string;
+  parentIndex: number;
+  removeParent: (index: number) => void;
+};
+
+export function DynamicGroupVariable({
+  name,
+  parentIndex,
+  removeParent,
+}: DynamicGroupVariableProps) {
+  const form = useFormContext();
+
+  const variableFieldName = `${name}.variables`;
+
+  const { getType } = useGetVariableLabelOrTypeByValue();
+
+  const { fields, remove, append } = useFieldArray({
+    name: variableFieldName,
+    control: form.control,
+  });
+
+  const firstValue = form.getValues(`${variableFieldName}.0.value`);
+  const firstType = getType(firstValue);
+
+  return (
+    <section className="py-3 group space-y-3">
+      <div className="flex items-center justify-between">
+        <div className="flex items-center gap-3">
+          <RAGFlowFormItem name={`${name}.group_name`} className="w-32">
+            <Input></Input>
+          </RAGFlowFormItem>
+
+          <Button
+            variant={'ghost'}
+            type="button"
+            className="hidden group-hover:block"
+            onClick={() => removeParent(parentIndex)}
+          >
+            <Trash2 />
+          </Button>
+        </div>
+        <div className="flex gap-2 items-center">
+          {firstType && (
+            <span className="text-text-secondary border px-1 rounded-md">
+              {firstType}
+            </span>
+          )}
+          <Button
+            variant={'ghost'}
+            type="button"
+            onClick={() => append({ value: '' })}
+          >
+            <Plus />
+          </Button>
+        </div>
+      </div>
+
+      <section className="space-y-3">
+        {fields.map((field, index) => (
+          <div key={field.id} className="flex gap-2 items-center">
+            <QueryVariable
+              name={`${variableFieldName}.${index}.value`}
+              className="flex-1 min-w-0"
+              hideLabel
+              types={firstType && fields.length > 1 ? [firstType] : []}
+            ></QueryVariable>
+            <Button
+              variant={'ghost'}
+              type="button"
+              onClick={() => remove(index)}
+            >
+              <Trash2 />
+            </Button>
+          </div>
+        ))}
+      </section>
+    </section>
+  );
+}
web/src/pages/agent/form/variable-aggregator-form/index.tsx (new file, 81 lines)
@@ -0,0 +1,81 @@
+import { BlockButton } from '@/components/ui/button';
+import { Form } from '@/components/ui/form';
+import { Separator } from '@/components/ui/separator';
+import { zodResolver } from '@hookform/resolvers/zod';
+import { memo } from 'react';
+import { useFieldArray, useForm } from 'react-hook-form';
+import { useTranslation } from 'react-i18next';
+import { z } from 'zod';
+import { initialDataOperationsValues } from '../../constant';
+import { useFormValues } from '../../hooks/use-form-values';
+import { useWatchFormChange } from '../../hooks/use-watch-form-change';
+import { INextOperatorForm } from '../../interface';
+import { buildOutputList } from '../../utils/build-output-list';
+import { FormWrapper } from '../components/form-wrapper';
+import { Output } from '../components/output';
+import { DynamicGroupVariable } from './dynamic-group-variable';
+
+export const RetrievalPartialSchema = {
+  groups: z.array(
+    z.object({
+      group_name: z.string(),
+      variables: z.array(z.object({ value: z.string().optional() })),
+    }),
+  ),
+  operations: z.string(),
+};
+
+export const FormSchema = z.object(RetrievalPartialSchema);
+
+export type DataOperationsFormSchemaType = z.infer<typeof FormSchema>;
+
+const outputList = buildOutputList(initialDataOperationsValues.outputs);
+
+function VariableAggregatorForm({ node }: INextOperatorForm) {
+  const { t } = useTranslation();
+
+  const defaultValues = useFormValues(initialDataOperationsValues, node);
+
+  const form = useForm<DataOperationsFormSchemaType>({
+    defaultValues: defaultValues,
+    mode: 'onChange',
+    resolver: zodResolver(FormSchema),
+    shouldUnregister: true,
+  });
+
+  const { fields, remove, append } = useFieldArray({
+    name: 'groups',
+    control: form.control,
+  });
+
+  useWatchFormChange(node?.id, form, true);
+
+  return (
+    <Form {...form}>
+      <FormWrapper>
+        <section className="divide-y">
+          {fields.map((field, idx) => (
+            <DynamicGroupVariable
+              key={field.id}
+              name={`groups.${idx}`}
+              parentIndex={idx}
+              removeParent={remove}
+            ></DynamicGroupVariable>
+          ))}
+        </section>
+        <BlockButton
+          onClick={() =>
+            append({ group_name: `Group ${fields.length}`, variables: [] })
+          }
+        >
+          {t('common.add')}
+        </BlockButton>
+        <Separator />
+
+        <Output list={outputList} isFormRequired></Output>
+      </FormWrapper>
+    </Form>
+  );
+}
+
+export default memo(VariableAggregatorForm);
@@ -1,6 +1,10 @@
-import { get } from 'lodash';
+import { get, isPlainObject } from 'lodash';
 import { ReactNode, useCallback } from 'react';
-import { AgentStructuredOutputField, Operator } from '../constant';
+import {
+  AgentStructuredOutputField,
+  JsonSchemaDataType,
+  Operator,
+} from '../constant';
 import useGraphStore from '../store';

 function getNodeId(value: string) {
@@ -82,3 +86,70 @@ export function useFindAgentStructuredOutputLabel() {

   return findAgentStructuredOutputLabel;
 }
+
+export function useFindAgentStructuredOutputTypeByValue() {
+  const { getOperatorTypeFromId } = useGraphStore((state) => state);
+  const filterStructuredOutput = useGetStructuredOutputByValue();
+
+  const findTypeByValue = useCallback(
+    (
+      values: unknown,
+      target: string,
+      path: string = '',
+    ): string | undefined => {
+      const properties =
+        get(values, 'properties') || get(values, 'items.properties');
+
+      if (isPlainObject(values) && properties) {
+        for (const [key, value] of Object.entries(properties)) {
+          const nextPath = path ? `${path}.${key}` : key;
+          const dataType = get(value, 'type');
+
+          if (nextPath === target) {
+            return dataType;
+          }
+
+          if (
+            [JsonSchemaDataType.Object, JsonSchemaDataType.Array].some(
+              (x) => x === dataType,
+            )
+          ) {
+            const type = findTypeByValue(value, target, nextPath);
+            if (type) {
+              return type;
+            }
+          }
+        }
+      }
+    },
+    [],
+  );
+
+  const findAgentStructuredOutputTypeByValue = useCallback(
+    (value?: string) => {
+      if (!value) {
+        return;
+      }
+      const fields = value.split('@');
+      const nodeId = fields.at(0);
+      const jsonSchema = filterStructuredOutput(value);
+
+      if (
+        getOperatorTypeFromId(nodeId) === Operator.Agent &&
+        fields.at(1)?.startsWith(AgentStructuredOutputField)
+      ) {
+        const jsonSchemaFields = fields
+          .at(1)
+          ?.slice(AgentStructuredOutputField.length + 1);
+
+        if (jsonSchemaFields) {
+          const type = findTypeByValue(jsonSchema, jsonSchemaFields);
+          return type;
+        }
+      }
+    },
+    [filterStructuredOutput, findTypeByValue, getOperatorTypeFromId],
+  );
+
+  return findAgentStructuredOutputTypeByValue;
+}
@@ -20,6 +20,7 @@ import { buildBeginInputListFromObject } from '../form/begin-form/utils';
 import { BeginQuery } from '../interface';
 import OperatorIcon from '../operator-icon';
 import useGraphStore from '../store';
+import { useFindAgentStructuredOutputTypeByValue } from './use-build-structured-output';

 export function useSelectBeginNodeDataInputs() {
   const getNode = useGraphStore((state) => state.getNode);
@@ -263,7 +264,7 @@ export const useGetComponentLabelByValue = (nodeId: string) => {
   return getLabel;
 };

-export function useGetVariableLabelByValue(nodeId: string) {
+export function useFlattenQueryVariableOptions(nodeId?: string) {
   const { getNode } = useGraphStore((state) => state);
   const nextOptions = useBuildQueryVariableOptions(getNode(nodeId));
@@ -273,11 +274,34 @@ export function useGetVariableLabelByValue(nodeId: string) {
     }, []);
   }, [nextOptions]);

-  const getLabel = useCallback(
+  return flattenOptions;
+}
+
+export function useGetVariableLabelOrTypeByValue(nodeId?: string) {
+  const flattenOptions = useFlattenQueryVariableOptions(nodeId);
+  const findAgentStructuredOutputTypeByValue =
+    useFindAgentStructuredOutputTypeByValue();
+
+  const getItem = useCallback(
     (val?: string) => {
-      return flattenOptions.find((x) => x.value === val)?.label;
+      return flattenOptions.find((x) => x.value === val);
     },
     [flattenOptions],
   );
-  return getLabel;
+
+  const getLabel = useCallback(
+    (val?: string) => {
+      return getItem(val)?.label;
+    },
+    [getItem],
+  );
+
+  const getType = useCallback(
+    (val?: string) => {
+      return getItem(val)?.type || findAgentStructuredOutputTypeByValue(val);
+    },
+    [findAgentStructuredOutputTypeByValue, getItem],
+  );
+
+  return { getLabel, getType };
 }