Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feat] Add runtime model management api #540

Merged
merged 6 commits into from
Dec 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions python/aibrix/aibrix/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import shutil
import time
from pathlib import Path
from typing import Optional
from urllib.parse import urljoin

import uvicorn
Expand All @@ -24,8 +25,11 @@
REGISTRY,
)
from aibrix.openapi.engine.base import InferenceEngine, get_inference_engine
from aibrix.openapi.model import ModelManager
from aibrix.openapi.protocol import (
DownloadModelRequest,
ErrorResponse,
ListModelRequest,
LoadLoraAdapterRequest,
UnloadLoraAdapterRequest,
)
Expand Down Expand Up @@ -120,6 +124,24 @@ async def unload_lora_adapter(request: UnloadLoraAdapterRequest, raw_request: Re
return Response(status_code=200, content=response)


@router.post("/v1/model/download")
async def download_model(request: DownloadModelRequest):
response = await ModelManager.model_download(request)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is async call right? from client perspective, how do I know when it's finished? so I can orchestrate model loading request afterwards?

Copy link
Collaborator Author

@brosoul brosoul Dec 24, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is async call right?

I initially planned to implement this async using coroutines, but later did not follow this approach 🤣 .
However, I am wondering if it is necessary to ensure that all API interfaces are aysnc? Or it can be partially aysnc and partially sync?

how do I know when it's finished?

Keep calling the post API until the model's status returns downloaded. Because this API will directly return the model status that implemented in #539 . And if necessary, a new process will be opened in the background for downloading, it will not wait for the download to complete before returning the result

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

model status could be sync. model download should be async but it introduces some complexity in orchestration. this is acceptable at this moment. I will get you involved in a meeting. VKE team is integrating this part

if isinstance(response, ErrorResponse):
return JSONResponse(content=response.model_dump(), status_code=response.code)

return JSONResponse(status_code=200, content=response.model_dump())


@router.get("/v1/model/list")
async def list_model(request: Optional[ListModelRequest] = None):
response = await ModelManager.model_list(request)
if isinstance(response, ErrorResponse):
return JSONResponse(content=response.model_dump(), status_code=response.code)

return JSONResponse(status_code=200, content=response.model_dump())


@router.get("/healthz")
async def liveness_check():
# Simply return a 200 status for liveness check
Expand Down
13 changes: 13 additions & 0 deletions python/aibrix/aibrix/common/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# Copyright 2024 The Aibrix Team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
56 changes: 56 additions & 0 deletions python/aibrix/aibrix/common/errors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
# Copyright 2024 The Aibrix Team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from typing import Optional


class InvalidArgumentError(ValueError):
pass


class ArgNotCongiuredError(InvalidArgumentError):
def __init__(self, arg_name: str, arg_source: Optional[str] = None):
self.arg_name = arg_name
self.message = f"Argument `{arg_name}` is not configured" + (
f" please check {arg_source}" if arg_source else ""
)
super().__init__(self.message)

def __str__(self):
return self.message


class ArgNotFormatError(InvalidArgumentError):
def __init__(self, arg_name: str, expected_format: str):
self.arg_name = arg_name
self.message = (
f"Argument `{arg_name}` is not in the expected format: {expected_format}"
)
super().__init__(self.message)

def __str__(self):
return self.message


class ModelNotFoundError(Exception):
def __init__(self, model_uri: str, detail_msg: Optional[str] = None):
self.model_uri = model_uri
self.message = f"Model not found at URI: {model_uri}" + (
f"\nDetails: {detail_msg}" if detail_msg else ""
)
super().__init__(self.message)

def __str__(self):
return self.message
7 changes: 5 additions & 2 deletions python/aibrix/aibrix/downloader/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional
from typing import Dict, Optional

from aibrix.downloader.base import get_downloader

Expand All @@ -21,6 +21,7 @@ def download_model(
model_uri: str,
local_path: Optional[str] = None,
model_name: Optional[str] = None,
download_extra_config: Optional[Dict] = None,
enable_progress_bar: bool = False,
):
"""Download model from model_uri to local_path.
Expand All @@ -30,7 +31,9 @@ def download_model(
local_path (str): local path to save model.
"""

downloader = get_downloader(model_uri, model_name, enable_progress_bar)
downloader = get_downloader(
model_uri, model_name, download_extra_config, enable_progress_bar
)
return downloader.download_model(local_path)


Expand Down
23 changes: 22 additions & 1 deletion python/aibrix/aibrix/downloader/__main__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,19 @@
import argparse
import json
from typing import Dict, Optional

from aibrix.downloader import download_model


def str_to_dict(s) -> Optional[Dict]:
if s is None:
return None
try:
return json.loads(s)
except Exception as e:
raise ValueError(f"Invalid json string {s}") from e


def main():
parser = argparse.ArgumentParser(description="Download model from HuggingFace")
parser.add_argument(
Expand Down Expand Up @@ -30,9 +41,19 @@ def main():
default=False,
help="Enable download progress bar during downloading from TOS or S3",
)
parser.add_argument(
"--download-extra-config",
type=str_to_dict,
default=None,
help="Extra config for download, like auth config, parallel config, etc.",
)
args = parser.parse_args()
download_model(
args.model_uri, args.local_dir, args.model_name, args.enable_progress_bar
args.model_uri,
args.local_dir,
args.model_name,
args.download_extra_config,
args.enable_progress_bar,
)


Expand Down
82 changes: 71 additions & 11 deletions python/aibrix/aibrix/downloader/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,45 @@
from concurrent.futures import ThreadPoolExecutor, wait
from dataclasses import dataclass, field
from pathlib import Path
from typing import List, Optional
from typing import ClassVar, Dict, List, Optional

from aibrix import envs
from aibrix.downloader.entity import RemoteSource
from aibrix.logger import init_logger

logger = init_logger(__name__)


@dataclass
class DownloadExtraConfig:
"""Downloader extra config."""

# Auth config for s3 or tos
ak: Optional[str] = None
sk: Optional[str] = None
endpoint: Optional[str] = None
region: Optional[str] = None

# Auth config for huggingface
hf_endpoint: Optional[str] = None
hf_token: Optional[str] = None
hf_revision: Optional[str] = None

# parrallel config
num_threads: Optional[int] = None
max_io_queue: Optional[int] = None
io_chunksize: Optional[int] = None
part_threshold: Optional[int] = None
part_chunksize: Optional[int] = None

# other config
allow_file_suffix: Optional[List[str]] = None
force_download: Optional[bool] = None


DEFAULT_DOWNLOADER_EXTRA_CONFIG = DownloadExtraConfig()


@dataclass
class BaseDownloader(ABC):
"""Base class for downloader."""
Expand All @@ -34,15 +65,27 @@ class BaseDownloader(ABC):
model_name: str
bucket_path: str
bucket_name: Optional[str]
enable_progress_bar: bool = False
allow_file_suffix: Optional[List[str]] = field(
default_factory=lambda: envs.DOWNLOADER_ALLOW_FILE_SUFFIX
download_extra_config: DownloadExtraConfig = field(
default_factory=DownloadExtraConfig
)
enable_progress_bar: bool = False
_source: ClassVar[RemoteSource] = RemoteSource.UNKNOWN

def __post_init__(self):
# valid downloader config
self._valid_config()
self.model_name_path = self.model_name
self.allow_file_suffix = (
self.download_extra_config.allow_file_suffix
or envs.DOWNLOADER_ALLOW_FILE_SUFFIX
)
self.force_download = (
self.download_extra_config.force_download or envs.DOWNLOADER_FORCE_DOWNLOAD
)

@property
def source(self) -> RemoteSource:
return self._source

@abstractmethod
def _valid_config(self):
Expand Down Expand Up @@ -81,7 +124,7 @@ def download_directory(self, local_path: Path):
# filter the directory path
files = [file for file in directory_list if not file.endswith("/")]

if self.allow_file_suffix is None:
if self.allow_file_suffix is None or len(self.allow_file_suffix) == 0:
logger.info(f"All files from {self.bucket_path} will be downloaded.")
filtered_files = files
else:
Expand All @@ -93,7 +136,9 @@ def download_directory(self, local_path: Path):

if not self._support_range_download():
# download using multi threads
num_threads = envs.DOWNLOADER_NUM_THREADS
num_threads = (
self.download_extra_config.num_threads or envs.DOWNLOADER_NUM_THREADS
)
logger.info(
f"Downloader {self.__class__.__name__} download "
f"{len(filtered_files)} files from {self.model_uri} "
Expand Down Expand Up @@ -157,23 +202,38 @@ def __hash__(self):


def get_downloader(
model_uri: str, model_name: Optional[str] = None, enable_progress_bar: bool = False
model_uri: str,
model_name: Optional[str] = None,
download_extra_config: Optional[Dict] = None,
enable_progress_bar: bool = False,
) -> BaseDownloader:
"""Get downloader for model_uri."""
download_config: DownloadExtraConfig = (
DEFAULT_DOWNLOADER_EXTRA_CONFIG
if download_extra_config is None
else DownloadExtraConfig(**download_extra_config)
)

if re.match(envs.DOWNLOADER_S3_REGEX, model_uri):
from aibrix.downloader.s3 import S3Downloader

return S3Downloader(model_uri, model_name, enable_progress_bar)
return S3Downloader(model_uri, model_name, download_config, enable_progress_bar)
elif re.match(envs.DOWNLOADER_TOS_REGEX, model_uri):
if envs.DOWNLOADER_TOS_VERSION == "v1":
from aibrix.downloader.tos import TOSDownloaderV1

return TOSDownloaderV1(model_uri, model_name, enable_progress_bar)
return TOSDownloaderV1(
model_uri, model_name, download_config, enable_progress_bar
)
else:
from aibrix.downloader.tos import TOSDownloaderV2

return TOSDownloaderV2(model_uri, model_name, enable_progress_bar)
return TOSDownloaderV2(
model_uri, model_name, download_config, enable_progress_bar
)
else:
from aibrix.downloader.huggingface import HuggingFaceDownloader

return HuggingFaceDownloader(model_uri, model_name, enable_progress_bar)
return HuggingFaceDownloader(
model_uri, model_name, download_config, enable_progress_bar
)
20 changes: 20 additions & 0 deletions python/aibrix/aibrix/downloader/entity.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,10 @@ class RemoteSource(Enum):
S3 = "s3"
TOS = "tos"
HUGGINGFACE = "huggingface"
UNKNOWN = "unknown"

def __str__(self):
return self.value


class FileDownloadStatus(Enum):
Expand All @@ -39,13 +43,20 @@ class FileDownloadStatus(Enum):
NO_OPERATION = "no_operation" # Interrupted from downloading
UNKNOWN = "unknown"

def __str__(self):
return self.value


class ModelDownloadStatus(Enum):
NOT_EXIST = "not_exist"
DOWNLOADING = "downloading"
DOWNLOADED = "downloaded"
NO_OPERATION = "no_operation" # Interrupted from downloading
UNKNOWN = "unknown"

def __str__(self):
return self.value


@dataclass
class DownloadFile:
Expand Down Expand Up @@ -125,13 +136,22 @@ def status(self):

return ModelDownloadStatus.UNKNOWN

@property
def model_root_path(self) -> Path:
return Path(self.local_path).joinpath(self.model_name)

@classmethod
def infer_from_model_path(
cls, local_path: Path, model_name: str, source: RemoteSource
) -> Optional["DownloadModel"]:
assert source is not None

model_base_dir = Path(local_path).joinpath(model_name)

# model not exists
if not model_base_dir.exists():
return None

cache_sub_dir = (DOWNLOAD_CACHE_DIR % source.value).strip("/")
cache_dir = Path(model_base_dir).joinpath(cache_sub_dir)
lock_files = list(Path(cache_dir).glob("*.lock"))
Expand Down
Loading
Loading