Add sdk for Agent API (#3220)

### What problem does this PR solve?

Add sdk for Agent API

### Type of change

- [x] New Feature (non-breaking change which adds functionality)

---------

Co-authored-by: liuhua <10215101452@stu.ecun.edu.cn>
This commit is contained in:
liuhua
2024-11-06 18:03:45 +08:00
committed by GitHub
parent 0dff64f6ad
commit f3aaa0d453
7 changed files with 546 additions and 92 deletions

View File

@ -7,4 +7,5 @@ from .modules.dataset import DataSet
from .modules.chat import Chat
from .modules.session import Session
from .modules.document import Document
from .modules.chunk import Chunk
from .modules.chunk import Chunk
from .modules.agent import Agent

View File

@ -0,0 +1,59 @@
from .base import Base
from .session import Session
import requests
class Agent(Base):
def __init__(self,rag,res_dict):
self.id = None
self.avatar = None
self.canvas_type = None
self.description = None
self.dsl = None
super().__init__(rag, res_dict)
class Dsl(Base):
def __init__(self,rag,res_dict):
self.answer = []
self.components = {
"begin": {
"downstream": ["Answer:China"],
"obj": {
"component_name": "Begin",
"params": {}
},
"upstream": []
}
}
self.graph = {
"edges": [],
"nodes": [
{
"data": {
"label": "Begin",
"name": "begin"
},
"id": "begin",
"position": {
"x": 50,
"y": 200
},
"sourcePosition": "left",
"targetPosition": "right",
"type": "beginNode"
}
]
}
self.history = []
self.messages = []
self.path = []
self.reference = []
super().__init__(rag,res_dict)
@staticmethod
def create_session(id,rag) -> Session:
res = requests.post(f"http://127.0.0.1:9380/api/v1/agents/{id}/sessions",headers={"Authorization": f"Bearer {rag.user_key}"},json={})
res = res.json()
if res.get("code") == 0:
return Session(rag,res.get("data"))
raise Exception(res.get("message"))

View File

@ -9,14 +9,19 @@ class Session(Base):
self.name = "New session"
self.messages = [{"role": "assistant", "content": "Hi! I am your assistantcan I help you?"}]
self.chat_id = None
self.agent_id = None
for key,value in res_dict.items():
if key =="chat_id" and value is not None:
self.__session_type = "chat"
if key == "agent_id" and value is not None:
self.__session_type = "agent"
super().__init__(rag, res_dict)
def ask(self, question: str, stream: bool = False):
for message in self.messages:
if "reference" in message:
message.pop("reference")
res = self.post(f"/chats/{self.chat_id}/completions",
{"question": question, "stream": True,"session_id":self.id}, stream=stream)
def ask(self, question):
if self.__session_type == "agent":
res=self._ask_agent(question)
elif self.__session_type == "chat":
res=self._ask_chat(question)
for line in res.iter_lines():
line = line.decode("utf-8")
if line.startswith("{"):
@ -33,25 +38,20 @@ class Session(Base):
}
if "chunks" in reference:
chunks = reference["chunks"]
chunk_list = []
for chunk in chunks:
new_chunk = {
"id": chunk["chunk_id"],
"content": chunk["content_with_weight"],
"document_id": chunk["doc_id"],
"document_name": chunk["docnm_kwd"],
"dataset_id": chunk["kb_id"],
"image_id": chunk["img_id"],
"similarity": chunk["similarity"],
"vector_similarity": chunk["vector_similarity"],
"term_similarity": chunk["term_similarity"],
"positions": chunk["positions"],
}
chunk_list.append(new_chunk)
temp_dict["reference"] = chunk_list
temp_dict["reference"] = chunks
message = Message(self.rag, temp_dict)
yield message
def _ask_chat(self, question: str, stream: bool = False):
res = self.post(f"/chats/{self.chat_id}/completions",
{"question": question, "stream": True,"session_id":self.id}, stream=stream)
return res
def _ask_agent(self,question:str,stream:bool=False):
res = self.post(f"/agents/{self.agent_id}/completions",
{"question": question, "stream": True,"session_id":self.id}, stream=stream)
return res
def update(self,update_message):
res = self.put(f"/chats/{self.chat_id}/sessions/{self.id}",
update_message)
@ -66,20 +66,4 @@ class Message(Base):
self.role = "assistant"
self.prompt = None
self.id = None
super().__init__(rag, res_dict)
class Chunk(Base):
def __init__(self, rag, res_dict):
self.id = None
self.content = None
self.document_id = ""
self.document_name = ""
self.dataset_id = ""
self.image_id = ""
self.similarity = None
self.vector_similarity = None
self.term_similarity = None
self.positions = None
super().__init__(rag, res_dict)
super().__init__(rag, res_dict)

View File

@ -1,5 +1,6 @@
from ragflow_sdk import RAGFlow
from ragflow_sdk import RAGFlow,Agent
from common import HOST_ADDRESS
import pytest
def test_create_session_with_success(get_api_key_fixture):
@ -58,6 +59,7 @@ def test_delete_sessions_with_success(get_api_key_fixture):
session = assistant.create_session()
assistant.delete_sessions(ids=[session.id])
def test_update_session_with_name(get_api_key_fixture):
API_KEY = get_api_key_fixture
rag = RAGFlow(API_KEY, HOST_ADDRESS)
@ -92,4 +94,17 @@ def test_list_sessions_with_success(get_api_key_fixture):
assistant=rag.create_chat("test_list_session", dataset_ids=[kb.id])
assistant.create_session("test_1")
assistant.create_session("test_2")
assistant.list_sessions()
assistant.list_sessions()
@pytest.mark.skip(reason="")
def test_create_agent_session_with_success(get_api_key_fixture):
API_KEY = "ragflow-BkOGNhYjIyN2JiODExZWY5MzVhMDI0Mm"
rag = RAGFlow(API_KEY,HOST_ADDRESS)
Agent.create_session("2e45b5209c1011efa3e90242ac120006", rag)
@pytest.mark.skip(reason="")
def test_create_agent_conversation_with_success(get_api_key_fixture):
API_KEY = "ragflow-BkOGNhYjIyN2JiODExZWY5MzVhMDI0Mm"
rag = RAGFlow(API_KEY,HOST_ADDRESS)
session = Agent.create_session("2e45b5209c1011efa3e90242ac120006", rag)
session.ask("What is this job")