Skip to content

Commit

Permalink
Introduced a model store
Browse files Browse the repository at this point in the history
Added a model store to standardize pulling, storing and using models
across the different repositories.

Signed-off-by: Michael Engel <[email protected]>
  • Loading branch information
engelmi committed Feb 13, 2025
1 parent 173cae3 commit c3fc4fd
Show file tree
Hide file tree
Showing 9 changed files with 450 additions and 215 deletions.
14 changes: 7 additions & 7 deletions ramalama/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -997,21 +997,21 @@ def rm_cli(args):

def New(model, args):
    """Create the Model implementation matching the transport prefix of *model*.

    Falls back to the configured default transport ("ollama" unless
    overridden in the config) when the model string carries no explicit
    transport prefix.

    Args:
        model (str): Model reference, optionally prefixed with a transport
            such as "hf://", "ollama://", "oci://" or "https://".
        args: Parsed CLI arguments; ``args.store`` (and ``args.engine`` for
            OCI models) is forwarded to the model constructor.

    Returns:
        A Huggingface, Ollama, OCI or URL model instance.

    Raises:
        KeyError: If the configured transport is not supported.
    """
    if model.startswith(("huggingface://", "hf://", "hf.co/")):
        return Huggingface(model, args.store)
    if model.startswith("ollama://") or "ollama.com/library/" in model:
        return Ollama(model, args.store)
    if model.startswith(("oci://", "docker://")):
        return OCI(model, args.engine, args.store)
    if model.startswith(("http://", "https://", "file://")):
        return URL(model, args.store)

    transport = config.get("transport", "ollama")
    if transport == "huggingface":
        return Huggingface(model, args.store)
    if transport == "ollama":
        return Ollama(model, args.store)
    if transport == "oci":
        return OCI(model, args.engine, args.store)

    raise KeyError(f'transport "{transport}" not supported. Must be oci, huggingface, or ollama.')

Expand Down
19 changes: 17 additions & 2 deletions ramalama/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

MNT_DIR = "/mnt/models"
MNT_FILE = f"{MNT_DIR}/model.file"
HTTP_NOT_FOUND = 404
HTTP_RANGE_NOT_SATISFIABLE = 416

DEFAULT_IMAGE = "quay.io/ramalama/ramalama"
Expand Down Expand Up @@ -154,6 +155,19 @@ def verify_checksum(filename):
# Compare the checksums
return sha256_hash.hexdigest() == expected_checksum

def generate_sha256(to_hash: str) -> str:
    """
    Generates a sha256 for a string.

    Args:
        to_hash (str): The string to generate the sha256 hash for.

    Returns:
        str: Hex digest of the input appended to the prefix sha256:
    """
    digest = hashlib.sha256(to_hash.encode("utf-8")).hexdigest()
    return "sha256:" + digest

# default_image function should figure out which GPU the system uses and
# then run the appropriate container image.
Expand Down Expand Up @@ -199,8 +213,9 @@ def download_file(url, dest_path, headers=None, show_progress=True):
return # Exit function if successful

except urllib.error.HTTPError as e:
if e.code == HTTP_RANGE_NOT_SATISFIABLE: # "Range Not Satisfiable" error (file already downloaded)
return # No need to retry
# "Range Not Satisfiable" error (file already downloaded)
if e.code in [HTTP_RANGE_NOT_SATISFIABLE, HTTP_NOT_FOUND]:
raise e

except urllib.error.URLError as e:
console.error(f"Network Error: {e.reason}")
Expand Down
2 changes: 1 addition & 1 deletion ramalama/http_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def urlopen(self, url, headers):
try:
self.response = urllib.request.urlopen(request)
except urllib.error.HTTPError as e:
raise IOError(f"Request failed: {e.code}") from e
raise e
except urllib.error.URLError as e:
raise IOError(f"Network error: {e.reason}") from e

Expand Down
187 changes: 109 additions & 78 deletions ramalama/huggingface.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import os
import pathlib
import urllib.request
from ramalama.common import available, run_cmd, exec_cmd, download_file, verify_checksum, perror
from ramalama.common import available, run_cmd, exec_cmd, download_file, verify_checksum, perror, generate_sha256
from ramalama.model import Model, rm_until_substring
from ramalama.model_store import ModelRegistry, SnapshotFile

missing_huggingface = """
Optional: Huggingface models require the huggingface-cli module.
Expand All @@ -14,29 +15,30 @@

def is_huggingface_cli_available():
    """Check if huggingface-cli is available on the system.

    Returns:
        bool: True if the huggingface-cli executable is available.
    """
    # Diff residue removed: the old if/else body duplicated this and was
    # unreachable after the return.
    return available("huggingface-cli")


def fetch_checksum_from_api(organization, file):
    """Fetch the SHA-256 checksum from the model's metadata API.

    Queries the raw pointer file on huggingface.co, which contains a line
    of the form "oid sha256:<digest>".

    Args:
        organization (str): Repository path of the model on huggingface.co.
        file (str): Name of the file inside the repository.

    Returns:
        str: The checksum in the form "sha256:<hex digest>".

    Raises:
        ValueError: If no "oid sha256:" line is present in the response.
        KeyError: If the metadata URL cannot be fetched.
    """
    checksum_api_url = f"https://huggingface.co/{organization}/raw/main/{file}"
    try:
        with urllib.request.urlopen(checksum_api_url) as response:
            data = response.read().decode()
        # Extract the SHA-256 checksum from the `oid sha256` line
        for line in data.splitlines():
            if line.startswith("oid sha256:"):
                return line.replace("oid", "").strip()
        raise ValueError("SHA-256 checksum not found in the API response.")
    except (urllib.error.HTTPError, urllib.error.URLError) as e:
        # Identical handling for both error types, merged into one clause.
        raise KeyError(f"failed to pull {checksum_api_url}: " + str(e).strip("'"))

class Huggingface(Model):
    """Model implementation backed by huggingface.co."""

    def __init__(self, model, store_path=""):
        """
        Args:
            model (str): Model reference; any "hf.co/" or "<scheme>://"
                prefix is stripped before it is handed to the base class.
            store_path (str): Root directory of the local model store.
        """
        # Normalize the reference: drop transport prefixes such as
        # "hf://", "huggingface://" or a leading "hf.co/".
        model = rm_until_substring(model, "hf.co/")
        model = rm_until_substring(model, "://")
        super().__init__(model, store_path, ModelRegistry.HUGGINGFACE)

        self.hf_cli_available = is_huggingface_cli_available()

def login(self, args):
Expand All @@ -55,68 +57,97 @@ def logout(self, args):
conman_args.extend(["--token", args.token])
self.exec(conman_args, args)

def pull(self, args):
model_path = self.model_path(args)
directory_path = os.path.join(args.store, "repos", "huggingface", self.directory, self.filename)
os.makedirs(directory_path, exist_ok=True)

symlink_dir = os.path.dirname(model_path)
os.makedirs(symlink_dir, exist_ok=True)

def pull(self, debug = False):
hash, cached_files, all = self.store.get_cached_files(self.model_tag)
if all:
return self.store.get_snapshot_file_path(hash, self.filename)

# Fetch the SHA-256 checksum of model from the API and use as snapshot hash
snapshot_hash = fetch_checksum_from_api(self.store.model_organization, self.store.model_name)

blob_url = f"https://huggingface.co/{self.store.model_organization}/resolve/main"
headers = {}

files: list[SnapshotFile] = []
model_file_name = self.store.model_name
config_file_name = "config.json"
generation_config_file_name = "generation_config.json"
tokenizer_config_file_name = "tokenizer_config.json"

if model_file_name not in cached_files:
files.append(
SnapshotFile(
url=f"{blob_url}/{model_file_name}",
header=headers,
hash=snapshot_hash,
name=model_file_name,
should_show_progress=True,
should_verify_checksum=True,
)
)
if config_file_name not in cached_files:
files.append(
SnapshotFile(
url=f"{blob_url}/{config_file_name}",
header=headers,
hash=generate_sha256(config_file_name),
name=config_file_name,
should_show_progress=False,
should_verify_checksum=False,
required=False,
)
)
if generation_config_file_name not in cached_files:
files.append(
SnapshotFile(
url=f"{blob_url}/{generation_config_file_name}",
header=headers,
hash=generate_sha256(generation_config_file_name),
name=generation_config_file_name,
should_show_progress=False,
should_verify_checksum=False,
required=False,
)
)
if tokenizer_config_file_name not in cached_files:
files.append(
SnapshotFile(
url=f"{blob_url}/{tokenizer_config_file_name}",
header=headers,
hash=generate_sha256(tokenizer_config_file_name),
name=tokenizer_config_file_name,
should_show_progress=False,
should_verify_checksum=False,
required=False,
)
)

try:
return self.url_pull(args, model_path, directory_path)
self.store.new_snapshot(self.model_tag, snapshot_hash, files)
except (urllib.error.HTTPError, urllib.error.URLError, KeyError) as e:
if self.hf_cli_available:
return self.hf_pull(args, model_path, directory_path)
perror("URL pull failed and huggingface-cli not available")
raise KeyError(f"Failed to pull model: {str(e)}")

def hf_pull(self, args, model_path, directory_path):
conman_args = ["huggingface-cli", "download", "--local-dir", directory_path, self.model]
run_cmd(conman_args, debug=args.debug)

relative_target_path = os.path.relpath(directory_path, start=os.path.dirname(model_path))
pathlib.Path(model_path).unlink(missing_ok=True)
os.symlink(relative_target_path, model_path)
return model_path

def url_pull(self, args, model_path, directory_path):
# Fetch the SHA-256 checksum from the API
checksum_api_url = f"https://huggingface.co/{self.directory}/raw/main/{self.filename}"
try:
sha256_checksum = fetch_checksum_from_api(checksum_api_url)
except urllib.error.HTTPError as e:
raise KeyError(f"failed to pull {checksum_api_url}: " + str(e).strip("'"))
except urllib.error.URLError as e:
raise KeyError(f"failed to pull {checksum_api_url}: " + str(e).strip("'"))

target_path = os.path.join(directory_path, f"sha256:{sha256_checksum}")

if os.path.exists(target_path) and verify_checksum(target_path):
relative_target_path = os.path.relpath(target_path, start=os.path.dirname(model_path))
if not self.check_valid_model_path(relative_target_path, model_path):
pathlib.Path(model_path).unlink(missing_ok=True)
os.symlink(relative_target_path, model_path)
return model_path

# Download the model file to the target path
url = f"https://huggingface.co/{self.directory}/resolve/main/{self.filename}"
download_file(url, target_path, headers={}, show_progress=True)
if not verify_checksum(target_path):
print(f"Checksum mismatch for {target_path}, retrying download...")
os.remove(target_path)
download_file(url, target_path, headers={}, show_progress=True)
if not verify_checksum(target_path):
raise ValueError(f"Checksum verification failed for {target_path}")

relative_target_path = os.path.relpath(target_path, start=os.path.dirname(model_path))
if self.check_valid_model_path(relative_target_path, model_path):
# Symlink is already correct, no need to update it
return model_path

pathlib.Path(model_path).unlink(missing_ok=True)
os.symlink(relative_target_path, model_path)
return model_path
if not self.hf_cli_available:
perror("URL pull failed and huggingface-cli not available")
raise KeyError(f"Failed to pull model: {str(e)}")

model_prefix = ""
if self.store.model_organization != "":
model_prefix = f"{self.store.model_organization}/"

self.store.prepare_new_snapshot(self.model_tag, snapshot_hash, files)
for file in files:
model = model_prefix + file
conman_args = ["huggingface-cli", "download", "--local-dir", self.store.blob_directory, model]
if run_cmd(conman_args, debug=debug) != 0 and not file.required:
continue

file_hash = generate_sha256(file)
blob_path = os.path.join(self.store.blob_directory, file_hash)
os.rename(src=os.path.join(self.store.blob_directory, model), dst=blob_path)

relative_target_path = os.path.relpath(blob_path, start=self.store.get_snapshot_directory(snapshot_hash))
os.symlink(relative_target_path, self.store.get_snapshot_file_path(snapshot_hash, file.name))

return self.store.get_snapshot_file_path(snapshot_hash, model_file_name)

def push(self, source, args):
if not self.hf_cli_available:
Expand Down
Loading

0 comments on commit c3fc4fd

Please sign in to comment.