diff --git a/download_deps.py b/download_deps.py index 10e08bdda..56e148e7a 100644 --- a/download_deps.py +++ b/download_deps.py @@ -14,6 +14,7 @@ import urllib.request from typing import Union import nltk +from huggingface_hub import snapshot_download def get_urls(use_china_mirrors=False) -> list[Union[str, list[str]]]: @@ -39,6 +40,19 @@ def get_urls(use_china_mirrors=False) -> list[Union[str, list[str]]]: ] +repos = [ + "InfiniFlow/text_concat_xgb_v1.0", + "InfiniFlow/deepdoc", + "InfiniFlow/huqie", +] + + +def download_model(repo_id): + local_dir = os.path.abspath(os.path.join("huggingface.co", repo_id)) + os.makedirs(local_dir, exist_ok=True) + snapshot_download(repo_id=repo_id, local_dir=local_dir) + + if __name__ == "__main__": parser = argparse.ArgumentParser(description="Download dependencies with optional China mirror support") parser.add_argument("--china-mirrors", action="store_true", help="Use China-accessible mirrors for downloads") @@ -57,3 +71,7 @@ if __name__ == "__main__": for data in ["wordnet", "punkt", "punkt_tab"]: print(f"Downloading nltk {data}...") nltk.download(data, download_dir=local_dir) + + for repo_id in repos: + print(f"Downloading huggingface repo {repo_id}...") + download_model(repo_id)