mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-01-31 15:45:08 +08:00
Refa: asyncio.to_thread to ThreadPoolExecutor to break thread limitat… (#12716)
### Type of change - [x] Refactoring
This commit is contained in:
@ -14,15 +14,20 @@
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import functools
|
||||
import hashlib
|
||||
import uuid
|
||||
import requests
|
||||
import threading
|
||||
import logging
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
import os
|
||||
import logging
|
||||
import threading
|
||||
import uuid
|
||||
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
|
||||
import requests
|
||||
|
||||
def get_uuid():
|
||||
return uuid.uuid1().hex
|
||||
@ -106,3 +111,22 @@ def pip_install_torch():
|
||||
logging.info("Installing pytorch")
|
||||
pkg_names = ["torch>=2.5.0,<3.0.0"]
|
||||
subprocess.check_call([sys.executable, "-m", "pip", "install", *pkg_names])
|
||||
|
||||
|
||||
def _thread_pool_executor():
|
||||
max_workers_env = os.getenv("THREAD_POOL_MAX_WORKERS", "128")
|
||||
try:
|
||||
max_workers = int(max_workers_env)
|
||||
except ValueError:
|
||||
max_workers = 128
|
||||
if max_workers < 1:
|
||||
max_workers = 1
|
||||
return ThreadPoolExecutor(max_workers=max_workers)
|
||||
|
||||
|
||||
async def thread_pool_exec(func, *args, **kwargs):
|
||||
loop = asyncio.get_running_loop()
|
||||
if kwargs:
|
||||
func = functools.partial(func, *args, **kwargs)
|
||||
return await loop.run_in_executor(_thread_pool_executor(), func)
|
||||
return await loop.run_in_executor(_thread_pool_executor(), func, *args)
|
||||
|
||||
Reference in New Issue
Block a user