Skip to content

Commit

Permalink
fix(sim/embed): add load_spacy to download models on demand
Browse files Browse the repository at this point in the history
  • Loading branch information
mirkolenz committed Feb 6, 2025
1 parent 9951d58 commit babc594
Show file tree
Hide file tree
Showing 5 changed files with 88 additions and 8 deletions.
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ dependencies = [
"polars>=1,<2",
"pydantic>=2,<3",
"pyyaml>=6,<7",
"rich>=13,<14",
"rtoml>=0.12,<1",
"scipy>=1,<2",
"xmltodict>=0.13,<1",
Expand All @@ -57,7 +58,7 @@ api = [
"python-multipart>=0.0.15,<1",
"uvicorn[standard]>=0.30,<1",
]
cli = ["rich>=13,<14", "typer>=0.9,<1"]
cli = ["typer>=0.9,<1"]
eval = ["ranx>=0.3,<1"]
graphs = ["networkx>=3,<4", "rustworkx>=0.15,<1"]
llm = [
Expand Down
2 changes: 2 additions & 0 deletions src/cbrkit/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

from . import (
adapt,
constants,
cycle,
dumpers,
eval,
Expand All @@ -37,6 +38,7 @@
"sim",
"synthesis",
"typing",
"constants",
]

logging.getLogger(__name__).addHandler(logging.NullHandler())
Expand Down
4 changes: 4 additions & 0 deletions src/cbrkit/constants.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from os import getenv
from pathlib import Path

CACHE_DIR = Path(getenv("CBRKIT_CACHE_DIR", Path.home() / ".cache" / "cbrkit"))
81 changes: 78 additions & 3 deletions src/cbrkit/sim/embed.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
import asyncio
import itertools
from collections.abc import MutableMapping, Sequence
from collections.abc import Iterator, MutableMapping, Sequence
from contextlib import AbstractContextManager
from dataclasses import dataclass, field
from pathlib import Path
from typing import Literal, cast, override

import numpy as np
from scipy.spatial.distance import cosine as scipy_cosine

from ..constants import CACHE_DIR
from ..helpers import (
batchify_conversion,
batchify_sim,
Expand Down Expand Up @@ -40,6 +42,7 @@
"cache",
"concat",
"spacy",
"load_spacy",
"sentence_transformers",
"openai",
"ollama",
Expand Down Expand Up @@ -226,9 +229,81 @@ def __call__(self, texts: Sequence[V]) -> Sequence[NumpyArray]:


with optional_dependencies():
from spacy import load as spacy_load
import spacy as spacylib
from spacy.cli.download import get_latest_version, get_model_filename
from spacy.language import Language

def load_spacy(name: str | None, cache_dir: Path = CACHE_DIR) -> Language:
import tarfile
import urllib.request

from rich.progress import Progress, TaskID

@dataclass(slots=True)
class ProgressHook(AbstractContextManager):
description: str
progress: Progress = field(default_factory=Progress, init=False)
task: TaskID | None = field(default=None, init=False)

def __enter__(self):
self.progress.start()
return self

def __exit__(self, exc_type, exc_value, traceback):
self.progress.stop()

def __call__(self, block_num: int, block_size: int, total_size: int):
if self.task is None:
self.task = self.progress.add_task(
self.description, total=total_size
)

downloaded = block_num * block_size

if downloaded < total_size:
self.progress.update(self.task, completed=downloaded)

if self.progress.finished:
self.task = None

def tarfile_members(
tf: tarfile.TarFile, prefix: str
) -> Iterator[tarfile.TarInfo]:
prefix_len = len(prefix)

for member in tf.getmembers():
if member.path.startswith(prefix):
member.path = member.path[prefix_len:]

yield member

if not name:
return spacylib.blank("en")

version = get_latest_version(name)
filename = get_model_filename(name, version, sdist=True)
versioned_name = f"{name}-{version}"
cache_file = cache_dir / "spacy" / versioned_name
tmpfile = cache_file.with_suffix(".tar.gz")

if not cache_file.exists():
cache_file.parent.mkdir(parents=True, exist_ok=True)
download_url = f"{spacylib.about.__download_url__}/{filename}"

with ProgressHook(
f"Downloading '{versioned_name}' to '{cache_file.parent}'..."
) as hook:
urllib.request.urlretrieve(download_url, tmpfile, hook)

with tarfile.open(tmpfile, mode="r:gz") as tf:
member_prefix = f"{versioned_name}/{name}/{versioned_name}/"
members = tarfile_members(tf, member_prefix)
tf.extractall(path=cache_file, members=members)

tmpfile.unlink()

return spacylib.load(cache_file)

@dataclass(slots=True)
class spacy(BatchConversionFunc[str, NumpyArray], HasMetadata):
"""Semantic similarity using [spaCy](https://spacy.io/)
Expand All @@ -242,7 +317,7 @@ class spacy(BatchConversionFunc[str, NumpyArray], HasMetadata):

def __init__(self, model: str | Language):
if isinstance(model, str):
self.model = spacy_load(model)
self.model = load_spacy(model)
else:
self.model = model

Expand Down
6 changes: 2 additions & 4 deletions uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit babc594

Please sign in to comment.