mirror of
https://github.com/infiniflow/ragflow.git
synced 2025-12-08 20:42:30 +08:00
Test APIs and fix bugs (#41)
This commit is contained in:
@ -19,31 +19,39 @@ import os
|
||||
|
||||
|
||||
class Base(ABC):
|
||||
def __init__(self, key, model_name):
|
||||
pass
|
||||
|
||||
def chat(self, system, history, gen_conf):
|
||||
raise NotImplementedError("Please implement encode method!")
|
||||
|
||||
|
||||
class GptTurbo(Base):
|
||||
def __init__(self):
|
||||
self.client = OpenAI(api_key=os.environ["OPENAI_API_KEY"])
|
||||
def __init__(self, key, model_name="gpt-3.5-turbo"):
|
||||
self.client = OpenAI(api_key=key)
|
||||
self.model_name = model_name
|
||||
|
||||
def chat(self, system, history, gen_conf):
|
||||
history.insert(0, {"role": "system", "content": system})
|
||||
res = self.client.chat.completions.create(
|
||||
model="gpt-3.5-turbo",
|
||||
model=self.model_name,
|
||||
messages=history,
|
||||
**gen_conf)
|
||||
return res.choices[0].message.content.strip()
|
||||
|
||||
|
||||
from dashscope import Generation
|
||||
class QWenChat(Base):
|
||||
def __init__(self, key, model_name=Generation.Models.qwen_turbo):
|
||||
import dashscope
|
||||
dashscope.api_key = key
|
||||
self.model_name = model_name
|
||||
|
||||
def chat(self, system, history, gen_conf):
|
||||
from http import HTTPStatus
|
||||
from dashscope import Generation
|
||||
# export DASHSCOPE_API_KEY=YOUR_DASHSCOPE_API_KEY
|
||||
history.insert(0, {"role": "system", "content": system})
|
||||
response = Generation.call(
|
||||
Generation.Models.qwen_turbo,
|
||||
self.model_name,
|
||||
messages=history,
|
||||
result_format='message'
|
||||
)
|
||||
|
||||
@ -28,6 +28,8 @@ class Base(ABC):
|
||||
raise NotImplementedError("Please implement encode method!")
|
||||
|
||||
def image2base64(self, image):
|
||||
if isinstance(image, bytes):
|
||||
return base64.b64encode(image).decode("utf-8")
|
||||
if isinstance(image, BytesIO):
|
||||
return base64.b64encode(image.getvalue()).decode("utf-8")
|
||||
buffered = BytesIO()
|
||||
@ -59,7 +61,7 @@ class Base(ABC):
|
||||
|
||||
class GptV4(Base):
|
||||
def __init__(self, key, model_name="gpt-4-vision-preview"):
|
||||
self.client = OpenAI(key)
|
||||
self.client = OpenAI(api_key = key)
|
||||
self.model_name = model_name
|
||||
|
||||
def describe(self, image, max_tokens=300):
|
||||
|
||||
Reference in New Issue
Block a user