Convert file format from Windows/DOS to Unix (#1949)

### What problem does this PR solve?

The affected source files were in Windows/DOS format; this PR converts
them to Unix format.

### Type of change

- [x] Refactoring

Signed-off-by: Jin Hai <haijin.chn@gmail.com>
This commit is contained in:
Jin Hai
2024-08-15 09:17:36 +08:00
committed by GitHub
parent 1328d715db
commit 6b3a40be5c
108 changed files with 36399 additions and 36399 deletions

View File

@ -1,171 +1,171 @@
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import argparse
import pickle
import random
import time
from copy import deepcopy
from multiprocessing.connection import Listener
from threading import Thread
from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer
def torch_gc():
    """Release cached accelerator memory held by PyTorch, best effort.

    When CUDA is available the CUDA cache is emptied and IPC handles are
    collected; otherwise, when the MPS backend is available, its cache is
    emptied. Every failure (including PyTorch not being importable) is
    ignored on purpose so that cleanup can never break a request.
    """
    try:
        import torch

        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()
        elif torch.backends.mps.is_available():
            # Imported lazily; a build without ``torch.mps`` is covered by
            # the handler below like any other failure.
            from torch.mps import empty_cache

            empty_cache()
    except Exception:
        # Deliberately swallowed: freeing caches is optional housekeeping.
        pass
class RPCHandler:
    """Dispatch pickled RPC requests to functions registered by name."""

    def __init__(self):
        # Maps a function's ``__name__`` to the callable that serves it.
        self._functions = {}

    def register_function(self, func):
        """Expose ``func`` to clients under the name ``func.__name__``."""
        self._functions[func.__name__] = func

    def handle_connection(self, connection):
        """Serve requests on ``connection`` until the peer disconnects.

        Each request is a pickled ``(func_name, args, kwargs)`` tuple. The
        reply is the pickled return value, or a pickled exception when the
        call fails, the name is unknown, or the result cannot be pickled.
        The connection is always closed when this method returns.

        Args:
            connection: Object providing ``recv()``, ``send()`` and
                ``close()``; ``recv()`` raising ``EOFError`` ends the loop.
        """
        try:
            while True:
                # SECURITY: ``pickle.loads`` can execute arbitrary code, so
                # only peers that are fully trusted may reach this handler.
                func_name, args, kwargs = pickle.loads(connection.recv())
                try:
                    result = self._functions[func_name](*args, **kwargs)
                    reply = pickle.dumps(result)
                except Exception as e:
                    reply = self._dump_exception(e)
                connection.send(reply)
        except EOFError:
            # The peer closed its end; nothing more to serve.
            pass
        finally:
            connection.close()

    @staticmethod
    def _dump_exception(exc):
        """Pickle ``exc``, degrading to a ``RuntimeError`` if it cannot be pickled."""
        try:
            return pickle.dumps(exc)
        except Exception:
            return pickle.dumps(RuntimeError(f"{type(exc).__name__}: {exc}"))
def rpc_server(hdlr, address, authkey):
    """Accept RPC clients forever, serving each on its own daemon thread.

    Args:
        hdlr: Object whose ``handle_connection(connection)`` serves one client.
        address: Address handed to ``multiprocessing.connection.Listener``.
        authkey: Secret handed to ``Listener`` as ``authkey``.
    """
    # The context manager closes the listener if the loop is ever left,
    # for example through KeyboardInterrupt.
    with Listener(address, authkey=authkey) as sock:
        while True:
            try:
                client = sock.accept()
                Thread(
                    target=hdlr.handle_connection,
                    args=(client,),
                    daemon=True,
                ).start()
            except Exception as e:
                # One failed accept must not take the whole server down.
                print("【EXCEPTION】:", str(e))
# Pool of loaded models; the ``__main__`` block fills it and ``Model()`` draws from it.
models = []
# Tokenizer used by ``chat`` and ``chat_streamly``; assigned in the ``__main__`` block.
tokenizer = None
def chat(messages, gen_conf):
    """Generate one complete reply for a chat conversation.

    Args:
        messages: Conversation passed to ``tokenizer.apply_chat_template``.
        gen_conf: Mapping of generation options; ``max_tokens`` (default
            256) and ``temperature`` (default 0.1) are read from it.

    Returns:
        The decoded reply text, or ``"**ERROR**: <message>"`` when anything
        fails, using the same marker that ``chat_streamly`` emits.
    """
    try:
        # Picked inside the ``try`` so an empty model pool is reported like
        # any other failure instead of escaping as an exception.
        model = Model()
        torch_gc()
        conf = {
            "max_new_tokens": int(gen_conf.get("max_tokens", 256)),
            "temperature": float(gen_conf.get("temperature", 0.1)),
        }
        print(messages, conf)
        text = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )
        model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
        generated_ids = model.generate(model_inputs.input_ids, **conf)
        # Drop the leading prompt-length tokens from every generated sequence.
        completions = [
            output_ids[len(input_ids):]
            for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
        ]
        return tokenizer.batch_decode(completions, skip_special_tokens=True)[0]
    except Exception as e:
        # Prefixed so callers can tell a failure from a genuine model reply.
        return "**ERROR**: " + str(e)
def chat_streamly(messages, gen_conf):
    """Stream a reply for a chat conversation fragment by fragment.

    Generation runs on a background thread while this generator yields the
    text that the streamer hands over.

    Args:
        messages: Conversation passed to ``tokenizer.apply_chat_template``.
        gen_conf: Mapping of keyword arguments for ``model.generate``. It is
            deep-copied, and ``max_tokens`` (default 256) is renamed to
            ``max_new_tokens``.

    Yields:
        Text fragments of the reply, or a single ``"**ERROR**: <message>"``
        string when setup or iteration fails.
    """
    try:
        # ``TextStreamer`` only prints to stdout and cannot be iterated; the
        # iterator variant is required to consume text from a loop. Imported
        # here because the module-level import list does not include it.
        from transformers import TextIteratorStreamer

        model = Model()
        torch_gc()
        conf = deepcopy(gen_conf)
        print(messages, conf)
        text = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )
        model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
        # Skip the prompt and special tokens so the stream carries the same
        # text that ``chat`` returns.
        streamer = TextIteratorStreamer(
            tokenizer, skip_prompt=True, skip_special_tokens=True
        )
        conf["inputs"] = model_inputs.input_ids
        conf["streamer"] = streamer
        # ``pop`` with a default avoids a KeyError when the caller omits it.
        conf["max_new_tokens"] = int(conf.pop("max_tokens", 256))
        thread = Thread(target=model.generate, kwargs=conf)
        thread.start()
        for new_text in streamer:
            yield new_text
    except Exception as e:
        yield "**ERROR**: " + str(e)
def Model():
    """Return one model chosen at random from the global ``models`` pool.

    Raises:
        IndexError: If ``models`` is empty.
    """
    # No reseeding here: calling ``random.seed`` on every request would
    # overwrite any seed set elsewhere in the process and adds nothing,
    # because the ``random`` module seeds itself on import.
    return random.choice(models)
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # Required: ``from_pretrained`` below has nothing to load without it, so
    # fail at argument parsing with a clear message instead.
    parser.add_argument("--model_name", type=str, required=True, help="Model name")
    parser.add_argument("--port", default=7860, type=int, help="RPC serving port")
    args = parser.parse_args()

    # Functions that clients may call by name.
    handler = RPCHandler()
    handler.register_function(chat)
    handler.register_function(chat_streamly)

    # Module-level globals read by ``Model()``, ``chat`` and ``chat_streamly``.
    # The pool currently holds a single model instance.
    models = [
        AutoModelForCausalLM.from_pretrained(
            args.model_name, device_map="auto", torch_dtype="auto"
        )
    ]
    tokenizer = AutoTokenizer.from_pretrained(args.model_name)

    # Run the server.
    # NOTE(review): the authkey is hard-coded and the server listens on all
    # interfaces while the handler unpickles whatever it receives; consider
    # loading the key from configuration and restricting the bind address.
    rpc_server(handler, ("0.0.0.0", args.port), authkey=b"infiniflow-token4kevinhu")
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import argparse
import pickle
import random
import time
from copy import deepcopy
from multiprocessing.connection import Listener
from threading import Thread
from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer
def torch_gc():
    """Release cached accelerator memory held by PyTorch, best effort.

    When CUDA is available the CUDA cache is emptied and IPC handles are
    collected; otherwise, when the MPS backend is available, its cache is
    emptied. Every failure (including PyTorch not being importable) is
    ignored on purpose so that cleanup can never break a request.
    """
    try:
        import torch

        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()
        elif torch.backends.mps.is_available():
            # Imported lazily; a build without ``torch.mps`` is covered by
            # the handler below like any other failure.
            from torch.mps import empty_cache

            empty_cache()
    except Exception:
        # Deliberately swallowed: freeing caches is optional housekeeping.
        pass
class RPCHandler:
    """Dispatch pickled RPC requests to functions registered by name."""

    def __init__(self):
        # Maps a function's ``__name__`` to the callable that serves it.
        self._functions = {}

    def register_function(self, func):
        """Expose ``func`` to clients under the name ``func.__name__``."""
        self._functions[func.__name__] = func

    def handle_connection(self, connection):
        """Serve requests on ``connection`` until the peer disconnects.

        Each request is a pickled ``(func_name, args, kwargs)`` tuple. The
        reply is the pickled return value, or a pickled exception when the
        call fails, the name is unknown, or the result cannot be pickled.
        The connection is always closed when this method returns.

        Args:
            connection: Object providing ``recv()``, ``send()`` and
                ``close()``; ``recv()`` raising ``EOFError`` ends the loop.
        """
        try:
            while True:
                # SECURITY: ``pickle.loads`` can execute arbitrary code, so
                # only peers that are fully trusted may reach this handler.
                func_name, args, kwargs = pickle.loads(connection.recv())
                try:
                    result = self._functions[func_name](*args, **kwargs)
                    reply = pickle.dumps(result)
                except Exception as e:
                    reply = self._dump_exception(e)
                connection.send(reply)
        except EOFError:
            # The peer closed its end; nothing more to serve.
            pass
        finally:
            connection.close()

    @staticmethod
    def _dump_exception(exc):
        """Pickle ``exc``, degrading to a ``RuntimeError`` if it cannot be pickled."""
        try:
            return pickle.dumps(exc)
        except Exception:
            return pickle.dumps(RuntimeError(f"{type(exc).__name__}: {exc}"))
def rpc_server(hdlr, address, authkey):
    """Accept RPC clients forever, serving each on its own daemon thread.

    Args:
        hdlr: Object whose ``handle_connection(connection)`` serves one client.
        address: Address handed to ``multiprocessing.connection.Listener``.
        authkey: Secret handed to ``Listener`` as ``authkey``.
    """
    # The context manager closes the listener if the loop is ever left,
    # for example through KeyboardInterrupt.
    with Listener(address, authkey=authkey) as sock:
        while True:
            try:
                client = sock.accept()
                Thread(
                    target=hdlr.handle_connection,
                    args=(client,),
                    daemon=True,
                ).start()
            except Exception as e:
                # One failed accept must not take the whole server down.
                print("【EXCEPTION】:", str(e))
# Pool of loaded models; the ``__main__`` block fills it and ``Model()`` draws from it.
models = []
# Tokenizer used by ``chat`` and ``chat_streamly``; assigned in the ``__main__`` block.
tokenizer = None
def chat(messages, gen_conf):
    """Generate one complete reply for a chat conversation.

    Args:
        messages: Conversation passed to ``tokenizer.apply_chat_template``.
        gen_conf: Mapping of generation options; ``max_tokens`` (default
            256) and ``temperature`` (default 0.1) are read from it.

    Returns:
        The decoded reply text, or ``"**ERROR**: <message>"`` when anything
        fails, using the same marker that ``chat_streamly`` emits.
    """
    try:
        # Picked inside the ``try`` so an empty model pool is reported like
        # any other failure instead of escaping as an exception.
        model = Model()
        torch_gc()
        conf = {
            "max_new_tokens": int(gen_conf.get("max_tokens", 256)),
            "temperature": float(gen_conf.get("temperature", 0.1)),
        }
        print(messages, conf)
        text = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )
        model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
        generated_ids = model.generate(model_inputs.input_ids, **conf)
        # Drop the leading prompt-length tokens from every generated sequence.
        completions = [
            output_ids[len(input_ids):]
            for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
        ]
        return tokenizer.batch_decode(completions, skip_special_tokens=True)[0]
    except Exception as e:
        # Prefixed so callers can tell a failure from a genuine model reply.
        return "**ERROR**: " + str(e)
def chat_streamly(messages, gen_conf):
    """Stream a reply for a chat conversation fragment by fragment.

    Generation runs on a background thread while this generator yields the
    text that the streamer hands over.

    Args:
        messages: Conversation passed to ``tokenizer.apply_chat_template``.
        gen_conf: Mapping of keyword arguments for ``model.generate``. It is
            deep-copied, and ``max_tokens`` (default 256) is renamed to
            ``max_new_tokens``.

    Yields:
        Text fragments of the reply, or a single ``"**ERROR**: <message>"``
        string when setup or iteration fails.
    """
    try:
        # ``TextStreamer`` only prints to stdout and cannot be iterated; the
        # iterator variant is required to consume text from a loop. Imported
        # here because the module-level import list does not include it.
        from transformers import TextIteratorStreamer

        model = Model()
        torch_gc()
        conf = deepcopy(gen_conf)
        print(messages, conf)
        text = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )
        model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
        # Skip the prompt and special tokens so the stream carries the same
        # text that ``chat`` returns.
        streamer = TextIteratorStreamer(
            tokenizer, skip_prompt=True, skip_special_tokens=True
        )
        conf["inputs"] = model_inputs.input_ids
        conf["streamer"] = streamer
        # ``pop`` with a default avoids a KeyError when the caller omits it.
        conf["max_new_tokens"] = int(conf.pop("max_tokens", 256))
        thread = Thread(target=model.generate, kwargs=conf)
        thread.start()
        for new_text in streamer:
            yield new_text
    except Exception as e:
        yield "**ERROR**: " + str(e)
def Model():
    """Return one model chosen at random from the global ``models`` pool.

    Raises:
        IndexError: If ``models`` is empty.
    """
    # No reseeding here: calling ``random.seed`` on every request would
    # overwrite any seed set elsewhere in the process and adds nothing,
    # because the ``random`` module seeds itself on import.
    return random.choice(models)
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # Required: ``from_pretrained`` below has nothing to load without it, so
    # fail at argument parsing with a clear message instead.
    parser.add_argument("--model_name", type=str, required=True, help="Model name")
    parser.add_argument("--port", default=7860, type=int, help="RPC serving port")
    args = parser.parse_args()

    # Functions that clients may call by name.
    handler = RPCHandler()
    handler.register_function(chat)
    handler.register_function(chat_streamly)

    # Module-level globals read by ``Model()``, ``chat`` and ``chat_streamly``.
    # The pool currently holds a single model instance.
    models = [
        AutoModelForCausalLM.from_pretrained(
            args.model_name, device_map="auto", torch_dtype="auto"
        )
    ]
    tokenizer = AutoTokenizer.from_pretrained(args.model_name)

    # Run the server.
    # NOTE(review): the authkey is hard-coded and the server listens on all
    # interfaces while the handler unpickles whatever it receives; consider
    # loading the key from configuration and restricting the bind address.
    rpc_server(handler, ("0.0.0.0", args.port), authkey=b"infiniflow-token4kevinhu")

View File

@ -1,89 +1,89 @@
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from openai.lib.azure import AzureOpenAI
from zhipuai import ZhipuAI
import io
from abc import ABC
from ollama import Client
from openai import OpenAI
import os
import json
from rag.utils import num_tokens_from_string
class Base(ABC):
    """Common base for speech-to-text backends.

    ``transcription`` reads ``self.client`` and ``self.model_name``;
    subclasses are responsible for setting both.
    """

    def __init__(self, key, model_name):
        # Intentionally empty: each subclass builds its own client.
        pass

    def transcription(self, audio, **kwargs):
        """Transcribe ``audio`` through ``self.client.audio.transcriptions``.

        Args:
            audio: Value passed through as the ``file`` argument.
            **kwargs: Accepted for interface compatibility; not used here.

        Returns:
            A ``(text, token_count)`` tuple: the stripped transcript and the
            result of ``num_tokens_from_string`` on it.
        """
        result = self.client.audio.transcriptions.create(
            model=self.model_name,
            file=audio,
            response_format="text",
        )
        # With ``response_format="text"`` the client may hand back a plain
        # string rather than an object with a ``.text`` attribute; accept both.
        raw_text = result if isinstance(result, str) else result.text
        text = raw_text.strip()
        return text, num_tokens_from_string(text)
class GPTSeq2txt(Base):
    """Speech-to-text backend that talks to an ``OpenAI`` client."""

    def __init__(self, key, model_name="whisper-1", base_url="https://api.openai.com/v1"):
        """Create the client.

        Args:
            key: API key passed to ``OpenAI``.
            model_name: Model identifier stored on the instance.
            base_url: Endpoint URL; a falsy value selects the default endpoint.
        """
        # ``None`` or ``""`` falls back to the public endpoint.
        endpoint = base_url or "https://api.openai.com/v1"
        self.client = OpenAI(api_key=key, base_url=endpoint)
        self.model_name = model_name
class QWenSeq2txt(Base):
    """Speech-to-text backend using DashScope's ``Recognition`` API."""

    def __init__(self, key, model_name="paraformer-realtime-8k-v1", **kwargs):
        """Store the model name and set the DashScope API key.

        Args:
            key: Assigned to ``dashscope.api_key`` (a module-level setting).
            model_name: Recognition model identifier.
            **kwargs: Accepted for interface compatibility; not used.
        """
        import dashscope

        dashscope.api_key = key
        self.model_name = model_name

    def transcription(self, audio, format):
        """Transcribe ``audio`` with DashScope.

        Args:
            audio: Value passed to ``Recognition.call``.
            format: Audio format passed to ``Recognition``.

        Returns:
            ``(text, token_count)`` on success, with one line per recognised
            sentence; ``("**ERROR**: <message>", 0)`` otherwise.
        """
        from http import HTTPStatus

        from dashscope.audio.asr import Recognition

        recognition = Recognition(
            model=self.model_name,
            format=format,
            sample_rate=16000,
            callback=None,
        )
        result = recognition.call(audio)
        if result.status_code != HTTPStatus.OK:
            # ``str()`` guards against a missing (``None``) message.
            return "**ERROR**: " + str(result.message), 0

        sentences = result.get_sentence()
        # NOTE(review): DashScope appears to return either one sentence
        # dict or a list of them — confirm against the SDK documentation.
        if sentences is None:
            sentences = []
        elif isinstance(sentences, dict):
            sentences = [sentences]
        ans = "".join(self._sentence_text(s) + "\n" for s in sentences)
        return ans, num_tokens_from_string(ans)

    @staticmethod
    def _sentence_text(sentence):
        """Return the text of one sentence entry, whatever its container type."""
        # The previous ``str(sentence + '\n')`` raised TypeError for any
        # non-string entry; presumably dict entries carry a "text" key.
        if isinstance(sentence, dict):
            return str(sentence.get("text", ""))
        return str(sentence)
class OllamaSeq2txt(Base):
    """Speech-to-text backend holding an Ollama ``Client``."""

    def __init__(self, key, model_name, lang="Chinese", **kwargs):
        """Create the client.

        Args:
            key: Accepted for interface compatibility; not used.
            model_name: Model identifier stored on the instance.
            lang: Language label stored on the instance.
            **kwargs: Must contain ``base_url``, used as the client host.

        Raises:
            KeyError: If ``base_url`` is not supplied.
        """
        host = kwargs["base_url"]
        self.client = Client(host=host)
        self.model_name, self.lang = model_name, lang
class AzureSeq2txt(Base):
    """Speech-to-text backend holding an ``AzureOpenAI`` client."""

    def __init__(self, key, model_name, lang="Chinese", **kwargs):
        """Create the client.

        Args:
            key: API key passed to ``AzureOpenAI``.
            model_name: Model identifier stored on the instance.
            lang: Language label stored on the instance.
            **kwargs: Must contain ``base_url``, used as the Azure endpoint.

        Raises:
            KeyError: If ``base_url`` is not supplied.
        """
        endpoint = kwargs["base_url"]
        self.client = AzureOpenAI(
            api_key=key,
            azure_endpoint=endpoint,
            api_version="2024-02-01",
        )
        self.model_name, self.lang = model_name, lang
class XinferenceSeq2txt(Base):
    """Speech-to-text backend that reaches Xinference through an ``OpenAI`` client."""

    def __init__(self, key, model_name="", base_url=""):
        """Create the client.

        Args:
            key: Accepted for interface compatibility; the client is always
                built with the placeholder key ``"xxx"``.
            model_name: Model identifier stored on the instance.
            base_url: Endpoint URL passed to ``OpenAI``.
        """
        client = OpenAI(base_url=base_url, api_key="xxx")
        self.client = client
        self.model_name = model_name
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from openai.lib.azure import AzureOpenAI
from zhipuai import ZhipuAI
import io
from abc import ABC
from ollama import Client
from openai import OpenAI
import os
import json
from rag.utils import num_tokens_from_string
class Base(ABC):
    """Common base for speech-to-text backends.

    ``transcription`` reads ``self.client`` and ``self.model_name``;
    subclasses are responsible for setting both.
    """

    def __init__(self, key, model_name):
        # Intentionally empty: each subclass builds its own client.
        pass

    def transcription(self, audio, **kwargs):
        """Transcribe ``audio`` through ``self.client.audio.transcriptions``.

        Args:
            audio: Value passed through as the ``file`` argument.
            **kwargs: Accepted for interface compatibility; not used here.

        Returns:
            A ``(text, token_count)`` tuple: the stripped transcript and the
            result of ``num_tokens_from_string`` on it.
        """
        result = self.client.audio.transcriptions.create(
            model=self.model_name,
            file=audio,
            response_format="text",
        )
        # With ``response_format="text"`` the client may hand back a plain
        # string rather than an object with a ``.text`` attribute; accept both.
        raw_text = result if isinstance(result, str) else result.text
        text = raw_text.strip()
        return text, num_tokens_from_string(text)
class GPTSeq2txt(Base):
    """Speech-to-text backend that talks to an ``OpenAI`` client."""

    def __init__(self, key, model_name="whisper-1", base_url="https://api.openai.com/v1"):
        """Create the client.

        Args:
            key: API key passed to ``OpenAI``.
            model_name: Model identifier stored on the instance.
            base_url: Endpoint URL; a falsy value selects the default endpoint.
        """
        # ``None`` or ``""`` falls back to the public endpoint.
        endpoint = base_url or "https://api.openai.com/v1"
        self.client = OpenAI(api_key=key, base_url=endpoint)
        self.model_name = model_name
class QWenSeq2txt(Base):
    """Speech-to-text backend using DashScope's ``Recognition`` API."""

    def __init__(self, key, model_name="paraformer-realtime-8k-v1", **kwargs):
        """Store the model name and set the DashScope API key.

        Args:
            key: Assigned to ``dashscope.api_key`` (a module-level setting).
            model_name: Recognition model identifier.
            **kwargs: Accepted for interface compatibility; not used.
        """
        import dashscope

        dashscope.api_key = key
        self.model_name = model_name

    def transcription(self, audio, format):
        """Transcribe ``audio`` with DashScope.

        Args:
            audio: Value passed to ``Recognition.call``.
            format: Audio format passed to ``Recognition``.

        Returns:
            ``(text, token_count)`` on success, with one line per recognised
            sentence; ``("**ERROR**: <message>", 0)`` otherwise.
        """
        from http import HTTPStatus

        from dashscope.audio.asr import Recognition

        recognition = Recognition(
            model=self.model_name,
            format=format,
            sample_rate=16000,
            callback=None,
        )
        result = recognition.call(audio)
        if result.status_code != HTTPStatus.OK:
            # ``str()`` guards against a missing (``None``) message.
            return "**ERROR**: " + str(result.message), 0

        sentences = result.get_sentence()
        # NOTE(review): DashScope appears to return either one sentence
        # dict or a list of them — confirm against the SDK documentation.
        if sentences is None:
            sentences = []
        elif isinstance(sentences, dict):
            sentences = [sentences]
        ans = "".join(self._sentence_text(s) + "\n" for s in sentences)
        return ans, num_tokens_from_string(ans)

    @staticmethod
    def _sentence_text(sentence):
        """Return the text of one sentence entry, whatever its container type."""
        # The previous ``str(sentence + '\n')`` raised TypeError for any
        # non-string entry; presumably dict entries carry a "text" key.
        if isinstance(sentence, dict):
            return str(sentence.get("text", ""))
        return str(sentence)
class OllamaSeq2txt(Base):
    """Speech-to-text backend holding an Ollama ``Client``."""

    def __init__(self, key, model_name, lang="Chinese", **kwargs):
        """Create the client.

        Args:
            key: Accepted for interface compatibility; not used.
            model_name: Model identifier stored on the instance.
            lang: Language label stored on the instance.
            **kwargs: Must contain ``base_url``, used as the client host.

        Raises:
            KeyError: If ``base_url`` is not supplied.
        """
        host = kwargs["base_url"]
        self.client = Client(host=host)
        self.model_name, self.lang = model_name, lang
class AzureSeq2txt(Base):
    """Speech-to-text backend holding an ``AzureOpenAI`` client."""

    def __init__(self, key, model_name, lang="Chinese", **kwargs):
        """Create the client.

        Args:
            key: API key passed to ``AzureOpenAI``.
            model_name: Model identifier stored on the instance.
            lang: Language label stored on the instance.
            **kwargs: Must contain ``base_url``, used as the Azure endpoint.

        Raises:
            KeyError: If ``base_url`` is not supplied.
        """
        endpoint = kwargs["base_url"]
        self.client = AzureOpenAI(
            api_key=key,
            azure_endpoint=endpoint,
            api_version="2024-02-01",
        )
        self.model_name, self.lang = model_name, lang
class XinferenceSeq2txt(Base):
    """Speech-to-text backend that reaches Xinference through an ``OpenAI`` client."""

    def __init__(self, key, model_name="", base_url=""):
        """Create the client.

        Args:
            key: Accepted for interface compatibility; the client is always
                built with the placeholder key ``"xxx"``.
            model_name: Model identifier stored on the instance.
            base_url: Endpoint URL passed to ``OpenAI``.
        """
        client = OpenAI(base_url=base_url, api_key="xxx")
        self.client = client
        self.model_name = model_name