Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Introduced a model store #805

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions ramalama/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -1007,21 +1007,21 @@ def rm_cli(args):

def New(model, args):
if model.startswith("huggingface://") or model.startswith("hf://") or model.startswith("hf.co/"):
return Huggingface(model)
return Huggingface(model, args.store)
if model.startswith("ollama://") or "ollama.com/library/" in model:
return Ollama(model)
return Ollama(model, args.store)
if model.startswith("oci://") or model.startswith("docker://"):
return OCI(model, args.engine)
return OCI(model, args.engine, args.store)
if model.startswith("http://") or model.startswith("https://") or model.startswith("file://"):
return URL(model)
return URL(model, args.store)

transport = config.get("transport", "ollama")
if transport == "huggingface":
return Huggingface(model)
return Huggingface(model, args.store)
if transport == "ollama":
return Ollama(model)
return Ollama(model, args.store)
if transport == "oci":
return OCI(model, args.engine)
return OCI(model, args.engine, args.store)

raise KeyError(f'transport "{transport}" not supported. Must be oci, huggingface, or ollama.')

Expand Down
19 changes: 17 additions & 2 deletions ramalama/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

MNT_DIR = "/mnt/models"
MNT_FILE = f"{MNT_DIR}/model.file"
HTTP_NOT_FOUND = 404
HTTP_RANGE_NOT_SATISFIABLE = 416

DEFAULT_IMAGE = "quay.io/ramalama/ramalama"
Expand Down Expand Up @@ -154,6 +155,19 @@ def verify_checksum(filename):
# Compare the checksums
return sha256_hash.hexdigest() == expected_checksum

def generate_sha256(to_hash: str) -> str:
    """
    Generates a sha256 for a string.

    Args:
        to_hash (str): The string to generate the sha256 hash for.

    Returns:
        str: Hex digest of the input, prefixed with "sha256:"
    """
    digest = hashlib.sha256(to_hash.encode("utf-8")).hexdigest()
    return f"sha256:{digest}"
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I like this function. I'm actually thinking about making a breaking change soon, replacing the ':' character with '-' on the filesystem like Ollama does — just not for this PR. For one, on some filesystems ':' is an illegal character.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Switching : to - makes a lot of sense.
Since this PR would introduce a breaking change regarding the file storage anyway, I can include this as well.


# The default_image function should figure out which GPU the system uses and
# then run the appropriate container image.
Expand Down Expand Up @@ -199,8 +213,9 @@ def download_file(url, dest_path, headers=None, show_progress=True):
return # Exit function if successful

except urllib.error.HTTPError as e:
if e.code == HTTP_RANGE_NOT_SATISFIABLE: # "Range Not Satisfiable" error (file already downloaded)
return # No need to retry
# "Range Not Satisfiable" error (file already downloaded)
if e.code in [HTTP_RANGE_NOT_SATISFIABLE, HTTP_NOT_FOUND]:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

+1

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done in #818

raise e

except urllib.error.URLError as e:
console.error(f"Network Error: {e.reason}")
Expand Down
2 changes: 1 addition & 1 deletion ramalama/http_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def urlopen(self, url, headers):
try:
self.response = urllib.request.urlopen(request)
except urllib.error.HTTPError as e:
raise IOError(f"Request failed: {e.code}") from e
raise e
except urllib.error.URLError as e:
raise IOError(f"Network error: {e.reason}") from e

Expand Down
187 changes: 109 additions & 78 deletions ramalama/huggingface.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import os
import pathlib
import urllib.request
from ramalama.common import available, run_cmd, exec_cmd, download_file, verify_checksum, perror
from ramalama.common import available, run_cmd, exec_cmd, download_file, verify_checksum, perror, generate_sha256
from ramalama.model import Model, rm_until_substring
from ramalama.model_store import ModelRegistry, SnapshotFile

missing_huggingface = """
Optional: Huggingface models require the huggingface-cli module.
Expand All @@ -14,29 +15,30 @@

def is_huggingface_cli_available():
"""Check if huggingface-cli is available on the system."""
if available("huggingface-cli"):
return True
else:
return False
return available("huggingface-cli")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

+1



def fetch_checksum_from_api(url):
def fetch_checksum_from_api(organization, file):
"""Fetch the SHA-256 checksum from the model's metadata API."""
with urllib.request.urlopen(url) as response:
data = response.read().decode()
# Extract the SHA-256 checksum from the `oid sha256` line
for line in data.splitlines():
if line.startswith("oid sha256:"):
return line.split(":", 1)[1].strip()
raise ValueError("SHA-256 checksum not found in the API response.")

checksum_api_url = f"https://huggingface.co/{organization}/raw/main/{file}"
try:
with urllib.request.urlopen(checksum_api_url) as response:
data = response.read().decode()
# Extract the SHA-256 checksum from the `oid sha256` line
for line in data.splitlines():
if line.startswith("oid sha256:"):
return line.replace("oid", "").strip()
raise ValueError("SHA-256 checksum not found in the API response.")
except urllib.error.HTTPError as e:
raise KeyError(f"failed to pull {checksum_api_url}: " + str(e).strip("'"))
except urllib.error.URLError as e:
raise KeyError(f"failed to pull {checksum_api_url}: " + str(e).strip("'"))

class Huggingface(Model):
def __init__(self, model):
def __init__(self, model, store_path=""):
model = rm_until_substring(model, "hf.co/")
model = rm_until_substring(model, "://")
super().__init__(model)
self.type = "huggingface"
super().__init__(model, store_path, ModelRegistry.HUGGINGFACE)

self.hf_cli_available = is_huggingface_cli_available()

def login(self, args):
Expand All @@ -55,68 +57,97 @@ def logout(self, args):
conman_args.extend(["--token", args.token])
self.exec(conman_args, args)

def pull(self, args):
model_path = self.model_path(args)
directory_path = os.path.join(args.store, "repos", "huggingface", self.directory, self.filename)
os.makedirs(directory_path, exist_ok=True)

symlink_dir = os.path.dirname(model_path)
os.makedirs(symlink_dir, exist_ok=True)

def pull(self, debug = False):
hash, cached_files, all = self.store.get_cached_files(self.model_tag)
if all:
return self.store.get_snapshot_file_path(hash, self.filename)

# Fetch the SHA-256 checksum of model from the API and use as snapshot hash
snapshot_hash = fetch_checksum_from_api(self.store.model_organization, self.store.model_name)

blob_url = f"https://huggingface.co/{self.store.model_organization}/resolve/main"
headers = {}

files: list[SnapshotFile] = []
model_file_name = self.store.model_name
config_file_name = "config.json"
generation_config_file_name = "generation_config.json"
tokenizer_config_file_name = "tokenizer_config.json"

if model_file_name not in cached_files:
files.append(
SnapshotFile(
url=f"{blob_url}/{model_file_name}",
header=headers,
hash=snapshot_hash,
name=model_file_name,
should_show_progress=True,
should_verify_checksum=True,
)
)
if config_file_name not in cached_files:
files.append(
SnapshotFile(
url=f"{blob_url}/{config_file_name}",
header=headers,
hash=generate_sha256(config_file_name),
name=config_file_name,
should_show_progress=False,
should_verify_checksum=False,
required=False,
)
)
if generation_config_file_name not in cached_files:
files.append(
SnapshotFile(
url=f"{blob_url}/{generation_config_file_name}",
header=headers,
hash=generate_sha256(generation_config_file_name),
name=generation_config_file_name,
should_show_progress=False,
should_verify_checksum=False,
required=False,
)
)
if tokenizer_config_file_name not in cached_files:
files.append(
SnapshotFile(
url=f"{blob_url}/{tokenizer_config_file_name}",
header=headers,
hash=generate_sha256(tokenizer_config_file_name),
name=tokenizer_config_file_name,
should_show_progress=False,
should_verify_checksum=False,
required=False,
)
)

try:
return self.url_pull(args, model_path, directory_path)
self.store.new_snapshot(self.model_tag, snapshot_hash, files)
except (urllib.error.HTTPError, urllib.error.URLError, KeyError) as e:
if self.hf_cli_available:
return self.hf_pull(args, model_path, directory_path)
perror("URL pull failed and huggingface-cli not available")
raise KeyError(f"Failed to pull model: {str(e)}")

def hf_pull(self, args, model_path, directory_path):
conman_args = ["huggingface-cli", "download", "--local-dir", directory_path, self.model]
run_cmd(conman_args, debug=args.debug)

relative_target_path = os.path.relpath(directory_path, start=os.path.dirname(model_path))
pathlib.Path(model_path).unlink(missing_ok=True)
os.symlink(relative_target_path, model_path)
return model_path

def url_pull(self, args, model_path, directory_path):
# Fetch the SHA-256 checksum from the API
checksum_api_url = f"https://huggingface.co/{self.directory}/raw/main/{self.filename}"
try:
sha256_checksum = fetch_checksum_from_api(checksum_api_url)
except urllib.error.HTTPError as e:
raise KeyError(f"failed to pull {checksum_api_url}: " + str(e).strip("'"))
except urllib.error.URLError as e:
raise KeyError(f"failed to pull {checksum_api_url}: " + str(e).strip("'"))

target_path = os.path.join(directory_path, f"sha256:{sha256_checksum}")

if os.path.exists(target_path) and verify_checksum(target_path):
relative_target_path = os.path.relpath(target_path, start=os.path.dirname(model_path))
if not self.check_valid_model_path(relative_target_path, model_path):
pathlib.Path(model_path).unlink(missing_ok=True)
os.symlink(relative_target_path, model_path)
return model_path

# Download the model file to the target path
url = f"https://huggingface.co/{self.directory}/resolve/main/{self.filename}"
download_file(url, target_path, headers={}, show_progress=True)
if not verify_checksum(target_path):
print(f"Checksum mismatch for {target_path}, retrying download...")
os.remove(target_path)
download_file(url, target_path, headers={}, show_progress=True)
if not verify_checksum(target_path):
raise ValueError(f"Checksum verification failed for {target_path}")

relative_target_path = os.path.relpath(target_path, start=os.path.dirname(model_path))
if self.check_valid_model_path(relative_target_path, model_path):
# Symlink is already correct, no need to update it
return model_path

pathlib.Path(model_path).unlink(missing_ok=True)
os.symlink(relative_target_path, model_path)
return model_path
if not self.hf_cli_available:
perror("URL pull failed and huggingface-cli not available")
raise KeyError(f"Failed to pull model: {str(e)}")

model_prefix = ""
if self.store.model_organization != "":
model_prefix = f"{self.store.model_organization}/"

self.store.prepare_new_snapshot(self.model_tag, snapshot_hash, files)
for file in files:
model = model_prefix + file
conman_args = ["huggingface-cli", "download", "--local-dir", self.store.blob_directory, model]
if run_cmd(conman_args, debug=debug) != 0 and not file.required:
continue

file_hash = generate_sha256(file)
blob_path = os.path.join(self.store.blob_directory, file_hash)
os.rename(src=os.path.join(self.store.blob_directory, model), dst=blob_path)

relative_target_path = os.path.relpath(blob_path, start=self.store.get_snapshot_directory(snapshot_hash))
os.symlink(relative_target_path, self.store.get_snapshot_file_path(snapshot_hash, file.name))

return self.store.get_snapshot_file_path(snapshot_hash, model_file_name)

def push(self, source, args):
if not self.hf_cli_available:
Expand Down
9 changes: 8 additions & 1 deletion ramalama/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from ramalama.common import MNT_DIR, MNT_FILE
from ramalama.model_inspect import GGUFModelInfo, ModelInfoBase
from ramalama.gguf_parser import GGUFInfoParser
from ramalama.model_store import ModelStore

MODEL_TYPES = ["file", "https", "http", "oci", "huggingface", "hf", "ollama"]

Expand All @@ -41,12 +42,18 @@ class Model:
model = ""
type = "Model"

def __init__(self, model):
def __init__(self, model, store_path="", model_registry=""):
self.model = model
self.model_tag = "latest"
if ":" in model:
self.model, self.model_tag = model.split(":", 1)

split = self.model.rsplit("/", 1)
self.directory = split[0] if len(split) > 1 else ""
self.filename = split[1] if len(split) > 1 else split[0]

self.store = ModelStore(store_path, self.filename, self.directory, model_registry)

def login(self, args):
raise NotImplementedError(f"ramalama login for {self.type} not implemented")

Expand Down
Loading
Loading