mirror of
https://github.com/infiniflow/ragflow.git
synced 2025-12-08 20:42:30 +08:00
Add sdk for Agent API (#3220)
### What problem does this PR solve? Add sdk for Agent API ### Type of change - [x] New Feature (non-breaking change which adds functionality) --------- Co-authored-by: liuhua <10215101452@stu.ecun.edu.cn>
This commit is contained in:
@ -7,4 +7,5 @@ from .modules.dataset import DataSet
|
||||
from .modules.chat import Chat
|
||||
from .modules.session import Session
|
||||
from .modules.document import Document
|
||||
from .modules.chunk import Chunk
|
||||
from .modules.chunk import Chunk
|
||||
from .modules.agent import Agent
|
||||
59
sdk/python/ragflow_sdk/modules/agent.py
Normal file
59
sdk/python/ragflow_sdk/modules/agent.py
Normal file
@ -0,0 +1,59 @@
|
||||
from .base import Base
|
||||
from .session import Session
|
||||
import requests
|
||||
|
||||
class Agent(Base):
|
||||
def __init__(self,rag,res_dict):
|
||||
self.id = None
|
||||
self.avatar = None
|
||||
self.canvas_type = None
|
||||
self.description = None
|
||||
self.dsl = None
|
||||
super().__init__(rag, res_dict)
|
||||
|
||||
class Dsl(Base):
|
||||
def __init__(self,rag,res_dict):
|
||||
self.answer = []
|
||||
self.components = {
|
||||
"begin": {
|
||||
"downstream": ["Answer:China"],
|
||||
"obj": {
|
||||
"component_name": "Begin",
|
||||
"params": {}
|
||||
},
|
||||
"upstream": []
|
||||
}
|
||||
}
|
||||
self.graph = {
|
||||
"edges": [],
|
||||
"nodes": [
|
||||
{
|
||||
"data": {
|
||||
"label": "Begin",
|
||||
"name": "begin"
|
||||
},
|
||||
"id": "begin",
|
||||
"position": {
|
||||
"x": 50,
|
||||
"y": 200
|
||||
},
|
||||
"sourcePosition": "left",
|
||||
"targetPosition": "right",
|
||||
"type": "beginNode"
|
||||
}
|
||||
]
|
||||
}
|
||||
self.history = []
|
||||
self.messages = []
|
||||
self.path = []
|
||||
self.reference = []
|
||||
super().__init__(rag,res_dict)
|
||||
|
||||
@staticmethod
|
||||
def create_session(id,rag) -> Session:
|
||||
res = requests.post(f"http://127.0.0.1:9380/api/v1/agents/{id}/sessions",headers={"Authorization": f"Bearer {rag.user_key}"},json={})
|
||||
res = res.json()
|
||||
if res.get("code") == 0:
|
||||
return Session(rag,res.get("data"))
|
||||
raise Exception(res.get("message"))
|
||||
|
||||
@ -9,14 +9,19 @@ class Session(Base):
|
||||
self.name = "New session"
|
||||
self.messages = [{"role": "assistant", "content": "Hi! I am your assistant,can I help you?"}]
|
||||
self.chat_id = None
|
||||
self.agent_id = None
|
||||
for key,value in res_dict.items():
|
||||
if key =="chat_id" and value is not None:
|
||||
self.__session_type = "chat"
|
||||
if key == "agent_id" and value is not None:
|
||||
self.__session_type = "agent"
|
||||
super().__init__(rag, res_dict)
|
||||
|
||||
def ask(self, question: str, stream: bool = False):
|
||||
for message in self.messages:
|
||||
if "reference" in message:
|
||||
message.pop("reference")
|
||||
res = self.post(f"/chats/{self.chat_id}/completions",
|
||||
{"question": question, "stream": True,"session_id":self.id}, stream=stream)
|
||||
def ask(self, question):
|
||||
if self.__session_type == "agent":
|
||||
res=self._ask_agent(question)
|
||||
elif self.__session_type == "chat":
|
||||
res=self._ask_chat(question)
|
||||
for line in res.iter_lines():
|
||||
line = line.decode("utf-8")
|
||||
if line.startswith("{"):
|
||||
@ -33,25 +38,20 @@ class Session(Base):
|
||||
}
|
||||
if "chunks" in reference:
|
||||
chunks = reference["chunks"]
|
||||
chunk_list = []
|
||||
for chunk in chunks:
|
||||
new_chunk = {
|
||||
"id": chunk["chunk_id"],
|
||||
"content": chunk["content_with_weight"],
|
||||
"document_id": chunk["doc_id"],
|
||||
"document_name": chunk["docnm_kwd"],
|
||||
"dataset_id": chunk["kb_id"],
|
||||
"image_id": chunk["img_id"],
|
||||
"similarity": chunk["similarity"],
|
||||
"vector_similarity": chunk["vector_similarity"],
|
||||
"term_similarity": chunk["term_similarity"],
|
||||
"positions": chunk["positions"],
|
||||
}
|
||||
chunk_list.append(new_chunk)
|
||||
temp_dict["reference"] = chunk_list
|
||||
temp_dict["reference"] = chunks
|
||||
message = Message(self.rag, temp_dict)
|
||||
yield message
|
||||
|
||||
|
||||
def _ask_chat(self, question: str, stream: bool = False):
|
||||
res = self.post(f"/chats/{self.chat_id}/completions",
|
||||
{"question": question, "stream": True,"session_id":self.id}, stream=stream)
|
||||
return res
|
||||
def _ask_agent(self,question:str,stream:bool=False):
|
||||
res = self.post(f"/agents/{self.agent_id}/completions",
|
||||
{"question": question, "stream": True,"session_id":self.id}, stream=stream)
|
||||
return res
|
||||
|
||||
def update(self,update_message):
|
||||
res = self.put(f"/chats/{self.chat_id}/sessions/{self.id}",
|
||||
update_message)
|
||||
@ -66,20 +66,4 @@ class Message(Base):
|
||||
self.role = "assistant"
|
||||
self.prompt = None
|
||||
self.id = None
|
||||
super().__init__(rag, res_dict)
|
||||
|
||||
|
||||
class Chunk(Base):
|
||||
def __init__(self, rag, res_dict):
|
||||
self.id = None
|
||||
self.content = None
|
||||
self.document_id = ""
|
||||
self.document_name = ""
|
||||
self.dataset_id = ""
|
||||
self.image_id = ""
|
||||
self.similarity = None
|
||||
self.vector_similarity = None
|
||||
self.term_similarity = None
|
||||
self.positions = None
|
||||
super().__init__(rag, res_dict)
|
||||
|
||||
super().__init__(rag, res_dict)
|
||||
@ -1,5 +1,6 @@
|
||||
from ragflow_sdk import RAGFlow
|
||||
from ragflow_sdk import RAGFlow,Agent
|
||||
from common import HOST_ADDRESS
|
||||
import pytest
|
||||
|
||||
|
||||
def test_create_session_with_success(get_api_key_fixture):
|
||||
@ -58,6 +59,7 @@ def test_delete_sessions_with_success(get_api_key_fixture):
|
||||
session = assistant.create_session()
|
||||
assistant.delete_sessions(ids=[session.id])
|
||||
|
||||
|
||||
def test_update_session_with_name(get_api_key_fixture):
|
||||
API_KEY = get_api_key_fixture
|
||||
rag = RAGFlow(API_KEY, HOST_ADDRESS)
|
||||
@ -92,4 +94,17 @@ def test_list_sessions_with_success(get_api_key_fixture):
|
||||
assistant=rag.create_chat("test_list_session", dataset_ids=[kb.id])
|
||||
assistant.create_session("test_1")
|
||||
assistant.create_session("test_2")
|
||||
assistant.list_sessions()
|
||||
assistant.list_sessions()
|
||||
|
||||
@pytest.mark.skip(reason="")
|
||||
def test_create_agent_session_with_success(get_api_key_fixture):
|
||||
API_KEY = "ragflow-BkOGNhYjIyN2JiODExZWY5MzVhMDI0Mm"
|
||||
rag = RAGFlow(API_KEY,HOST_ADDRESS)
|
||||
Agent.create_session("2e45b5209c1011efa3e90242ac120006", rag)
|
||||
|
||||
@pytest.mark.skip(reason="")
|
||||
def test_create_agent_conversation_with_success(get_api_key_fixture):
|
||||
API_KEY = "ragflow-BkOGNhYjIyN2JiODExZWY5MzVhMDI0Mm"
|
||||
rag = RAGFlow(API_KEY,HOST_ADDRESS)
|
||||
session = Agent.create_session("2e45b5209c1011efa3e90242ac120006", rag)
|
||||
session.ask("What is this job")
|
||||
Reference in New Issue
Block a user