export cli

This commit is contained in:
mertalev
2023-11-12 18:01:12 -05:00
parent 069a32dcdb
commit ae80def7f2
5 changed files with 135 additions and 71 deletions

View File

@@ -1,76 +1,131 @@
from enum import StrEnum
import gc
import os
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Optional
from huggingface_hub import create_repo, login, upload_folder
from models import mclip, openclip
from huggingface_hub import create_repo, upload_folder
from export.models import mclip, openclip, insightface
from export.models.util import clean_name
from rich.progress import Progress
import typer
# Model identifiers handled by the batch export loop in this file.
# Entries written as "<architecture>::<pretrained tag>" are split into a name
# and a pretrained tag and exported through openclip; entries starting with
# "M-CLIP" are passed unchanged to the mclip exporter.
models = [
"RN50::openai",
"RN50::yfcc15m",
"RN50::cc12m",
"RN101::openai",
"RN101::yfcc15m",
"RN50x4::openai",
"RN50x16::openai",
"RN50x64::openai",
"ViT-B-32::openai",
"ViT-B-32::laion2b_e16",
"ViT-B-32::laion400m_e31",
"ViT-B-32::laion400m_e32",
"ViT-B-32::laion2b-s34b-b79k",
"ViT-B-16::openai",
"ViT-B-16::laion400m_e31",
"ViT-B-16::laion400m_e32",
"ViT-B-16-plus-240::laion400m_e31",
"ViT-B-16-plus-240::laion400m_e32",
"ViT-L-14::openai",
"ViT-L-14::laion400m_e31",
"ViT-L-14::laion400m_e32",
"ViT-L-14::laion2b-s32b-b82k",
"ViT-L-14-336::openai",
"ViT-H-14::laion2b-s32b-b79k",
"ViT-g-14::laion2b-s12b-b42k",
"M-CLIP/LABSE-Vit-L-14",
"M-CLIP/XLM-Roberta-Large-Vit-B-32",
"M-CLIP/XLM-Roberta-Large-Vit-B-16Plus",
"M-CLIP/XLM-Roberta-Large-Vit-L-14",
]
# Log in to Hugging Face as soon as the module is loaded. The token is read by
# indexing os.environ directly, so a missing HF_AUTH_TOKEN raises KeyError here.
login(token=os.environ["HF_AUTH_TOKEN"])
# Typer application object; the `export` and `upload` commands register on it.
app = typer.Typer()
with Progress() as progress:
task1 = progress.add_task("[green]Exporting models...", total=len(models))
task2 = progress.add_task("[yellow]Uploading models...", total=len(models))
with TemporaryDirectory() as tmp:
tmpdir = Path(tmp)
for model in models:
model_name = model.split("/")[-1].replace("::", "__")
config_path = tmpdir / model_name / "config.json"
class ModelLibrary(StrEnum):
    """Libraries whose models can be exported.

    Used as the type of the ``--library`` option of the ``export`` command and
    to pick the exporter module in ``_export``.
    """

    # Each value is the lowercase name of the corresponding exporter module.
    MCLIP = "mclip"
    OPENCLIP = "openclip"
    INSIGHTFACE = "insightface"
def upload() -> None:
progress.update(task2, description=f"[yellow]Uploading {model_name}")
repo_id = f"immich-app/{model_name}"
create_repo(repo_id, exist_ok=True)
upload_folder(repo_id=repo_id, folder_path=tmpdir / model_name)
progress.update(task2, advance=1)
def _export(model_name: str, library: ModelLibrary, export_dir: Path) -> None:
visual_dir = export_dir / "visual"
textual_dir = export_dir / "textual"
match library:
case ModelLibrary.MCLIP:
insightface.to_onnx(model_name, visual_dir, textual_dir)
case ModelLibrary.OPENCLIP:
mclip.to_onnx(model_name, visual_dir, textual_dir)
case ModelLibrary.INSIGHTFACE:
name, _, pretrained = model_name.partition("__")
openclip.to_onnx(openclip.OpenCLIPModelConfig(name, pretrained), visual_dir, textual_dir)
def export() -> None:
progress.update(task1, description=f"[green]Exporting {model_name}")
visual_dir = tmpdir / model_name / "visual"
textual_dir = tmpdir / model_name / "textual"
if model.startswith("M-CLIP"):
mclip.to_onnx(model, visual_dir, textual_dir)
else:
name, _, pretrained = model_name.partition("__")
openclip.to_onnx(openclip.OpenCLIPModelConfig(name, pretrained), visual_dir, textual_dir)
gc.collect()
progress.update(task1, advance=1)
gc.collect()
export()
upload()
def _upload(repo_id: str, upload_dir: Path, auth_token: str | None = None) -> None:
    """Create the Hugging Face repo if needed and upload a folder into it.

    Args:
        repo_id: Target repository id passed to ``create_repo`` / ``upload_folder``.
        upload_dir: Local folder whose contents are uploaded.
        auth_token: Token passed to both Hugging Face calls. When ``None``,
            the ``HF_AUTH_TOKEN`` environment variable is consulted at call
            time (it may still be ``None`` if the variable is unset).
    """
    # Resolve the fallback here rather than in the signature: a default of
    # `os.environ.get(...)` is evaluated once, when the function is defined,
    # and would miss a variable set after import.
    if auth_token is None:
        auth_token = os.environ.get("HF_AUTH_TOKEN")
    create_repo(repo_id, exist_ok=True, token=auth_token)
    upload_folder(repo_id=repo_id, folder_path=upload_dir, token=auth_token)
@app.command()
def export(
    models: list[str] = typer.Argument(
        ..., help="The model(s) to be exported. Model names should be the same as used in the associated library."
    ),
    library: ModelLibrary = typer.Option(
        ..., "--library", "-l", help="The library associated with the models to be exported."
    ),
    output_dir: Optional[Path] = typer.Option(
        None,
        "--output-dir",
        "-o",
        help="Directory where exported models will be stored. Defaults to a temporary directory.",
    ),
    should_upload: bool = typer.Option(False, "--upload", "-u", help="Whether to upload the exported models."),
    # The token comes from the HF_AUTH_TOKEN environment variable via `envvar`
    # instead of being baked into the default at import time, which would let
    # the secret show up as the option's default in `--help`.
    auth_token: Optional[str] = typer.Option(
        None,
        "--auth_token",
        "-t",
        envvar="HF_AUTH_TOKEN",
        help="If uploading models to Hugging Face, the auth token of the user or organisation.",
    ),
    repo_prefix: str = typer.Option(
        "immich-app",
        "--repo_prefix",
        "-p",
        help="If uploading models to Hugging Face, the prefix to put before the model name. Can be a username or organisation.",
    ),
) -> None:
    """Export the given models to ONNX and optionally upload them to Hugging Face."""
    if not models:
        raise ValueError("No models specified")

    # The temporary directory is the outermost context so that, when no
    # --output-dir is given, the exported files still exist while uploading.
    with TemporaryDirectory() as tmp:
        export_root = output_dir if output_dir else Path(tmp)

        with Progress() as progress:
            export_task = progress.add_task("[green]Exporting model(s)...", total=len(models))
            for model_name in models:
                cleaned_name = clean_name(model_name)
                progress.update(export_task, description=f"[green]Exporting {cleaned_name}")
                # Each model gets its own subdirectory named after the cleaned model name.
                _export(model_name, library, export_root / cleaned_name)
                progress.update(export_task, advance=1, description=f"[green]Exported {cleaned_name}")

        # upload() opens its own Progress display, so it runs only after the
        # export display above has been closed; starting a second live display
        # while one is active is rejected by Rich releases that do not support
        # nested live displays.
        if should_upload:
            upload(models, export_root, auth_token, repo_prefix)
@app.command()
def upload(
    models: list[str] = typer.Argument(
        ..., help="The model(s) to be uploaded. Model names should be the same as used in the associated library."
    ),
    # Unlike `export`, there is no temporary-directory fallback here: the
    # directory must already contain the exported models.
    output_dir: Optional[Path] = typer.Option(
        None,
        "--output-dir",
        "-o",
        help="Directory where the exported models to upload are stored.",
    ),
    # Read the token from HF_AUTH_TOKEN via `envvar` rather than embedding the
    # environment value as the default, which would expose it in `--help`.
    auth_token: Optional[str] = typer.Option(
        None,
        "--auth_token",
        "-t",
        envvar="HF_AUTH_TOKEN",
        help="The Hugging Face auth token of the user or organisation.",
    ),
    repo_prefix: str = typer.Option(
        "immich-app",
        "--repo_prefix",
        "-p",
        help="The name to put before the model name to form the Hugging Face repo name. Can be a username or organisation.",
    ),
) -> None:
    """Upload previously exported models to Hugging Face, one repo per model."""
    if not models:
        raise ValueError("No models specified")
    # Without this guard a missing --output-dir surfaces as an opaque
    # TypeError from `None / cleaned_name` further down.
    if output_dir is None:
        raise ValueError("No output directory specified")

    with Progress() as progress:
        upload_task = progress.add_task("[yellow]Uploading models...", total=len(models))
        for model_name in models:
            cleaned_name = clean_name(model_name)
            # Repo id is "<prefix>/<cleaned model name>"; the files are read
            # from the subdirectory of the same cleaned name.
            repo_id = f"{repo_prefix}/{cleaned_name}"
            model_dir = output_dir / cleaned_name
            progress.update(upload_task, description=f"[yellow]Uploading {cleaned_name}")
            _upload(repo_id, model_dir, auth_token)
            progress.update(upload_task, advance=1, description=f"[yellow]Uploaded {cleaned_name}")
# Run the Typer application when this file is executed as a script.
if __name__ == "__main__":
    app()