import os
from pathlib import Path
from typing import Iterable, List, Literal, Optional, Union, overload

import httpx
from tqdm.auto import tqdm as base_tqdm
from tqdm.contrib.concurrent import thread_map

from . import constants
from .errors import (
    DryRunError,
    GatedRepoError,
    HfHubHTTPError,
    LocalEntryNotFoundError,
    RepositoryNotFoundError,
    RevisionNotFoundError,
)
from .file_download import REGEX_COMMIT_HASH, DryRunFileInfo, hf_hub_download, repo_folder_name
from .hf_api import DatasetInfo, HfApi, ModelInfo, RepoFile, SpaceInfo
from .utils import OfflineModeIsEnabled, filter_repo_objects, is_tqdm_disabled, logging, validate_hf_hub_args
from .utils import tqdm as hf_tqdm


logger = logging.get_logger(__name__)

LARGE_REPO_THRESHOLD = 1000  # After this limit, we don't consider `repo_info.siblings` to be reliable enough


@overload
def snapshot_download(
    repo_id: str,
    *,
    repo_type: Optional[str] = None,
    revision: Optional[str] = None,
    cache_dir: Union[str, Path, None] = None,
    local_dir: Union[str, Path, None] = None,
    library_name: Optional[str] = None,
    library_version: Optional[str] = None,
    user_agent: Optional[Union[dict, str]] = None,
    etag_timeout: float = constants.DEFAULT_ETAG_TIMEOUT,
    force_download: bool = False,
    token: Optional[Union[bool, str]] = None,
    local_files_only: bool = False,
    allow_patterns: Optional[Union[list[str], str]] = None,
    ignore_patterns: Optional[Union[list[str], str]] = None,
    max_workers: int = 8,
    tqdm_class: Optional[type[base_tqdm]] = None,
    headers: Optional[dict[str, str]] = None,
    endpoint: Optional[str] = None,
    dry_run: Literal[False] = False,
) -> str: ...


@overload
def snapshot_download(
    repo_id: str,
    *,
    repo_type: Optional[str] = None,
    revision: Optional[str] = None,
    cache_dir: Union[str, Path, None] = None,
    local_dir: Union[str, Path, None] = None,
    library_name: Optional[str] = None,
    library_version: Optional[str] = None,
    user_agent: Optional[Union[dict, str]] = None,
    etag_timeout: float = constants.DEFAULT_ETAG_TIMEOUT,
    force_download: bool = False,
    token: Optional[Union[bool, str]] = None,
    local_files_only: bool = False,
    allow_patterns: Optional[Union[list[str], str]] = None,
    ignore_patterns: Optional[Union[list[str], str]] = None,
    max_workers: int = 8,
    tqdm_class: Optional[type[base_tqdm]] = None,
    headers: Optional[dict[str, str]] = None,
    endpoint: Optional[str] = None,
    dry_run: Literal[True] = True,
) -> list[DryRunFileInfo]: ...


@overload
def snapshot_download(
    repo_id: str,
    *,
    repo_type: Optional[str] = None,
    revision: Optional[str] = None,
    cache_dir: Union[str, Path, None] = None,
    local_dir: Union[str, Path, None] = None,
    library_name: Optional[str] = None,
    library_version: Optional[str] = None,
    user_agent: Optional[Union[dict, str]] = None,
    etag_timeout: float = constants.DEFAULT_ETAG_TIMEOUT,
    force_download: bool = False,
    token: Optional[Union[bool, str]] = None,
    local_files_only: bool = False,
    allow_patterns: Optional[Union[list[str], str]] = None,
    ignore_patterns: Optional[Union[list[str], str]] = None,
    max_workers: int = 8,
    tqdm_class: Optional[type[base_tqdm]] = None,
    headers: Optional[dict[str, str]] = None,
    endpoint: Optional[str] = None,
    dry_run: bool = False,
) -> Union[str, list[DryRunFileInfo]]: ...


@validate_hf_hub_args
def snapshot_download(
    repo_id: str,
    *,
    repo_type: Optional[str] = None,
    revision: Optional[str] = None,
    cache_dir: Union[str, Path, None] = None,
    local_dir: Union[str, Path, None] = None,
    library_name: Optional[str] = None,
    library_version: Optional[str] = None,
    user_agent: Optional[Union[dict, str]] = None,
    etag_timeout: float = constants.DEFAULT_ETAG_TIMEOUT,
    force_download: bool = False,
    token: Optional[Union[bool, str]] = None,
    local_files_only: bool = False,
    allow_patterns: Optional[Union[list[str], str]] = None,
    ignore_patterns: Optional[Union[list[str], str]] = None,
    max_workers: int = 8,
    tqdm_class: Optional[type[base_tqdm]] = None,
    headers: Optional[dict[str, str]] = None,
    endpoint: Optional[str] = None,
    dry_run: bool = False,
) -> Union[str, list[DryRunFileInfo]]:
    """Download repo files.

    Download a whole snapshot of a repo's files at the specified revision. This is useful when you want all files
    from a repo, because you don't know which ones you will need a priori. All files are nested inside a folder in
    order to keep their actual filename relative to that folder. You can also filter which files to download using
    `allow_patterns` and `ignore_patterns`.

    If `local_dir` is provided, the file structure from the repo will be replicated in this location. When using
    this option, the `cache_dir` will not be used and a `.cache/huggingface/` folder will be created at the root of
    `local_dir` to store some metadata related to the downloaded files. While this mechanism is not as robust as the
    main cache-system, it's optimized for regularly pulling the latest version of a repository.

    An alternative would be to clone the repo, but this requires git and git-lfs to be installed and properly
    configured. It is also not possible to filter which files to download when cloning a repository using git.

    Args:
        repo_id (`str`):
            A user or an organization name and a repo name separated by a `/`.
        repo_type (`str`, *optional*):
            Set to `"dataset"` or `"space"` if downloading from a dataset or space,
            `None` or `"model"` if downloading from a model. Default is `None`.
        revision (`str`, *optional*):
            An optional Git revision id which can be a branch name, a tag, or a
            commit hash.
        cache_dir (`str`, `Path`, *optional*):
            Path to the folder where cached files are stored.
        local_dir (`str` or `Path`, *optional*):
            If provided, the downloaded files will be placed under this directory.
        library_name (`str`, *optional*):
            The name of the library to which the object corresponds.
        library_version (`str`, *optional*):
            The version of the library.
        user_agent (`str`, `dict`, *optional*):
            The user-agent info in the form of a dictionary or a string.
        etag_timeout (`float`, *optional*, defaults to `10`):
            When fetching ETag, how many seconds to wait for the server to send
            data before giving up. This value is passed to `httpx.request`.
        force_download (`bool`, *optional*, defaults to `False`):
            Whether the file should be downloaded even if it already exists in the local cache.
        token (`str`, `bool`, *optional*):
            A token to be used for the download.

            - If `True`, the token is read from the HuggingFace config
              folder.
            - If a string, it's used as the authentication token.
        headers (`dict`, *optional*):
            Additional headers to include in the request. Those headers take precedence over the others.
        local_files_only (`bool`, *optional*, defaults to `False`):
            If `True`, avoid downloading the file and return the path to the
            local cached file if it exists.
        allow_patterns (`list[str]` or `str`, *optional*):
            If provided, only files matching at least one pattern are downloaded.
        ignore_patterns (`list[str]` or `str`, *optional*):
            If provided, files matching any of the patterns are not downloaded.
        max_workers (`int`, *optional*):
            Number of concurrent threads to download files (1 thread = 1 file download).
            Defaults to 8.
        tqdm_class (`tqdm`, *optional*):
            If provided, overwrites the default behavior for the progress bar. Passed
            argument must inherit from `tqdm.auto.tqdm` or at least mimic its behavior.
            Note that the `tqdm_class` is not passed to each individual download.
            Defaults to the custom HF progress bar that can be disabled by setting
            the `HF_HUB_DISABLE_PROGRESS_BARS` environment variable.
        dry_run (`bool`, *optional*, defaults to `False`):
            If `True`, perform a dry run without actually downloading the files. Returns a list of
            [`DryRunFileInfo`] objects containing information about what would be downloaded.

    Returns:
        `str` or list of [`DryRunFileInfo`]:
            - If `dry_run=False`: Local snapshot path.
            - If `dry_run=True`: A list of [`DryRunFileInfo`] objects containing download information.

    Raises:
        [`~utils.RepositoryNotFoundError`]
            If the repository to download from cannot be found. This may be because it doesn't exist,
            or because it is set to `private` and you do not have access.
        [`~utils.RevisionNotFoundError`]
            If the revision to download from cannot be found.
        [`EnvironmentError`](https://docs.python.org/3/library/exceptions.html#EnvironmentError)
            If `token=True` and the token cannot be found.
        [`OSError`](https://docs.python.org/3/library/exceptions.html#OSError)
            If the ETag cannot be determined.
        [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError)
            If some parameter value is invalid.
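
    Example (repo ids and file patterns below are illustrative):

    ```python
    >>> from huggingface_hub import snapshot_download

    >>> # Download a full model snapshot to the cache and get the local folder path
    >>> snapshot_download(repo_id="user/repo")

    >>> # Only fetch the config and safetensors files of a given revision
    >>> snapshot_download(repo_id="user/repo", revision="main", allow_patterns=["*.json", "*.safetensors"])

    >>> # Replicate the repo structure into a local directory instead of the cache
    >>> snapshot_download(repo_id="user/repo", local_dir="./my-repo")

    >>> # Inspect what would be downloaded without fetching anything
    >>> snapshot_download(repo_id="user/repo", dry_run=True)
    ```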
    """
    if cache_dir is None:
        cache_dir = constants.HF_HUB_CACHE
    if revision is None:
        revision = constants.DEFAULT_REVISION
    if isinstance(cache_dir, Path):
        cache_dir = str(cache_dir)

    if repo_type is None:
        repo_type = "model"
    if repo_type not in constants.REPO_TYPES:
        raise ValueError(f"Invalid repo type: {repo_type}. Accepted repo types are: {str(constants.REPO_TYPES)}")

    storage_folder = os.path.join(cache_dir, repo_folder_name(repo_id=repo_id, repo_type=repo_type))
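
    # A single `HfApi` client is reused for all Hub calls below so that library name/version,
    # user-agent, endpoint and authentication settings are applied consistently.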
    api = HfApi(
        library_name=library_name,
        library_version=library_version,
        user_agent=user_agent,
        endpoint=endpoint,
        headers=headers,
        token=token,
    )

    repo_info: Union[ModelInfo, DatasetInfo, SpaceInfo, None] = None
    api_call_error: Optional[Exception] = None
    if not local_files_only:
        # try/except logic to handle different errors => taken from `hf_hub_download`
        try:
            # if we have internet connection we want to list files to download
            repo_info = api.repo_info(repo_id=repo_id, repo_type=repo_type, revision=revision)
        except httpx.ProxyError:
            # Actually raise on proxy error
            raise
        except (httpx.ConnectError, httpx.TimeoutException, OfflineModeIsEnabled) as error:
            # Internet connection is down
            # => will try to use local files only
            api_call_error = error
        except RevisionNotFoundError:
            # The repo was found but the revision doesn't exist on the Hub (never existed or got deleted)
            raise
        except HfHubHTTPError as error:
            # Multiple reasons for an http error:
            # - Repository is private and invalid/missing token sent
            # - Repository is gated and invalid/missing token sent
            # - Hub is down (error 500 or 504)
            # => let's switch to 'local_files_only=True' to check if the files are already cached.
            #    (if it's not the case, the error will be re-raised)
            api_call_error = error

    # At this stage, if `repo_info` is None it means either:
    # - internet connection is down
    # - internet connection is deactivated (local_files_only=True or HF_HUB_OFFLINE=True)
    # - repo is private/gated and invalid/missing token sent
    # - Hub is down
    # => let's look if we can find the appropriate folder in the cache:
    #    - if the specified revision is a commit hash, look inside "snapshots".
    #    - if the specified revision is a branch or tag, look inside "refs".
    # => if local_dir is not None, we will return the path to the local folder if it exists.
    if repo_info is None:
        if dry_run:
            raise DryRunError(
                "Dry run cannot be performed as the repository cannot be accessed. Please check your internet connection or authentication token."
            ) from api_call_error

        # Try to get which commit hash corresponds to the specified revision
        commit_hash = None
        if REGEX_COMMIT_HASH.match(revision):
            commit_hash = revision
        else:
            ref_path = os.path.join(storage_folder, "refs", revision)
            if os.path.exists(ref_path):
                # retrieve commit_hash from refs file
                with open(ref_path) as f:
                    commit_hash = f.read()

        # Try to locate snapshot folder for this commit hash
        if commit_hash is not None and local_dir is None:
            snapshot_folder = os.path.join(storage_folder, "snapshots", commit_hash)
            if os.path.exists(snapshot_folder):
                # Snapshot folder exists => let's return it
                # (but we can't check if all the files are actually there)
                return snapshot_folder

        # If local_dir is not None, return it if it exists and is not empty
        if local_dir is not None:
            local_dir = Path(local_dir)
            if local_dir.is_dir() and any(local_dir.iterdir()):
                logger.warning(
                    f"Returning existing local_dir `{local_dir}` as remote repo cannot be accessed in `snapshot_download` ({api_call_error})."
                )
                return str(local_dir.resolve())

        # If we couldn't find the appropriate folder on disk, raise an error.
        if local_files_only:
            raise LocalEntryNotFoundError(
                "Cannot find an appropriate cached snapshot folder for the specified revision on the local disk and "
                "outgoing traffic has been disabled. To enable repo look-ups and downloads online, pass "
                "'local_files_only=False' as input."
            )
        elif isinstance(api_call_error, OfflineModeIsEnabled):
            raise LocalEntryNotFoundError(
                "Cannot find an appropriate cached snapshot folder for the specified revision on the local disk and "
                "outgoing traffic has been disabled. To enable repo look-ups and downloads online, set "
                "'HF_HUB_OFFLINE=0' as environment variable."
            ) from api_call_error
        elif isinstance(api_call_error, (RepositoryNotFoundError, GatedRepoError)) or (
            isinstance(api_call_error, HfHubHTTPError) and api_call_error.response.status_code == 401
        ):
            # Repo not found, gated, or specific authentication error => let's raise the actual error
            raise api_call_error
        else:
            # Otherwise: most likely a connection issue or Hub downtime => let's warn the user
            raise LocalEntryNotFoundError(
                "An error happened while trying to locate the files on the Hub and we cannot find the appropriate"
                " snapshot folder for the specified revision on the local disk. Please check your internet connection"
                " and try again."
            ) from api_call_error

    # At this stage, internet connection is up and running
    # => let's download the files!
    assert repo_info.sha is not None, "Repo info returned from server must have a revision sha."

    # Corner case: on very large repos, the siblings list in `repo_info` might not contain all files.
    # In that case, we need to use the `list_repo_tree` method to prevent caching issues.
    repo_files: Iterable[str] = [f.rfilename for f in repo_info.siblings] if repo_info.siblings is not None else []
    unreliable_nb_files = (
        repo_info.siblings is None or len(repo_info.siblings) == 0 or len(repo_info.siblings) > LARGE_REPO_THRESHOLD
    )
    if unreliable_nb_files:
        logger.info(
            "Number of files in the repo is unreliable. Using `list_repo_tree` to ensure all files are listed."
        )
        repo_files = (
            f.rfilename
            for f in api.list_repo_tree(repo_id=repo_id, recursive=True, revision=revision, repo_type=repo_type)
            if isinstance(f, RepoFile)
        )

    filtered_repo_files: Iterable[str] = filter_repo_objects(
        items=repo_files,
        allow_patterns=allow_patterns,
        ignore_patterns=ignore_patterns,
    )

    if not unreliable_nb_files:
        filtered_repo_files = list(filtered_repo_files)
        tqdm_desc = f"Fetching {len(filtered_repo_files)} files"
    else:
        tqdm_desc = "Fetching ... files"
    if dry_run:
        tqdm_desc = "[dry-run] " + tqdm_desc
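
    # `revision` has been resolved server-side to an exact commit hash; everything below is
    # cached and downloaded under that hash so the snapshot is deterministic.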
    commit_hash = repo_info.sha
    snapshot_folder = os.path.join(storage_folder, "snapshots", commit_hash)
    # If the passed revision is not identical to commit_hash, then it has to be a
    # branch name or tag name. In that case, store a ref.
    if revision != commit_hash:
        ref_path = os.path.join(storage_folder, "refs", revision)
        try:
            os.makedirs(os.path.dirname(ref_path), exist_ok=True)
            with open(ref_path, "w") as f:
                f.write(commit_hash)
        except OSError as e:
            logger.warning(f"Ignored error while writing commit hash to {ref_path}: {e}.")

    results: List[Union[str, DryRunFileInfo]] = []

    # Users can pass their own tqdm class; default to the one from `huggingface_hub.utils`.
    tqdm_class = tqdm_class or hf_tqdm

    # Create a progress bar for the bytes downloaded.
    # This progress bar is shared across threads/files and gets updated each time we fetch
    # metadata for a file.
    bytes_progress = tqdm_class(
        desc="Downloading (incomplete total...)",
        disable=is_tqdm_disabled(log_level=logger.getEffectiveLevel()),
        total=0,
        initial=0,
        unit="B",
        unit_scale=True,
        name="huggingface_hub.snapshot_download",
    )

    class _AggregatedTqdm:
        """Fake tqdm object to aggregate progress into the parent `bytes_progress` bar.

        In practice the `_AggregatedTqdm` object won't be displayed, it's just used to update
        the `bytes_progress` bar from each thread/file download.
        """

        def __init__(self, *args, **kwargs):
            # Adjust the total of the parent progress bar
            total = kwargs.pop("total", None)
            if total is not None:
                bytes_progress.total += total
                bytes_progress.refresh()

            # Adjust the initial value of the parent progress bar
            initial = kwargs.pop("initial", 0)
            if initial:
                bytes_progress.update(initial)

        def __enter__(self):
            return self

        def __exit__(self, exc_type, exc_value, traceback):
            pass

        def update(self, n: Optional[Union[int, float]] = 1) -> None:
            bytes_progress.update(n)

    # We pass the commit_hash to `hf_hub_download` so that no network call happens
    # if we already have the file locally.
    def _inner_hf_hub_download(repo_file: str) -> None:
        results.append(
            hf_hub_download(  # type: ignore
                repo_id,
                filename=repo_file,
                repo_type=repo_type,
                revision=commit_hash,
                endpoint=endpoint,
                cache_dir=cache_dir,
                local_dir=local_dir,
                library_name=library_name,
                library_version=library_version,
                user_agent=user_agent,
                etag_timeout=etag_timeout,
                force_download=force_download,
                token=token,
                headers=headers,
                tqdm_class=_AggregatedTqdm,  # type: ignore
                dry_run=dry_run,
            )
        )
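
    # Download files concurrently: each worker thread fetches one file end-to-end while the
    # shared `bytes_progress` bar aggregates the downloaded bytes across all of them.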
    thread_map(
        _inner_hf_hub_download,
        filtered_repo_files,
        desc=tqdm_desc,
        max_workers=max_workers,
        tqdm_class=tqdm_class,
    )

    bytes_progress.set_description("Download complete")
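
    # In dry-run mode, each worker appended a `DryRunFileInfo` instead of a local file path.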
    if dry_run:
        assert all(isinstance(r, DryRunFileInfo) for r in results)
        return results  # type: ignore

    if local_dir is not None:
        return str(os.path.realpath(local_dir))
    return snapshot_folder