mertalev
2025-03-11 18:35:21 -04:00
parent f5e44f12e1
commit ec0fa4d52b
22 changed files with 132 additions and 105 deletions


@@ -0,0 +1,69 @@
import os
import tempfile
import warnings
from pathlib import Path

import torch
from multilingual_clip.pt_multilingual_clip import MultilingualCLIP
from transformers import AutoTokenizer

from .openclip import OpenCLIPModelConfig
from .openclip import to_onnx as openclip_to_onnx
from .util import get_model_path

_MCLIP_TO_OPENCLIP = {
"M-CLIP/XLM-Roberta-Large-Vit-B-32": OpenCLIPModelConfig("ViT-B-32", "openai"),
"M-CLIP/XLM-Roberta-Large-Vit-B-16Plus": OpenCLIPModelConfig("ViT-B-16-plus-240", "laion400m_e32"),
"M-CLIP/LABSE-Vit-L-14": OpenCLIPModelConfig("ViT-L-14", "openai"),
"M-CLIP/XLM-Roberta-Large-Vit-L-14": OpenCLIPModelConfig("ViT-L-14", "openai"),
}


def to_onnx(
model_name: str,
output_dir_visual: Path | str,
output_dir_textual: Path | str,
) -> tuple[Path, Path]:
textual_path = get_model_path(output_dir_textual)
with tempfile.TemporaryDirectory() as tmpdir:
model = MultilingualCLIP.from_pretrained(model_name, cache_dir=os.environ.get("CACHE_DIR", tmpdir))
AutoTokenizer.from_pretrained(model_name).save_pretrained(output_dir_textual)
model.eval()
for param in model.parameters():
param.requires_grad_(False)
export_text_encoder(model, textual_path)
visual_path, _ = openclip_to_onnx(_MCLIP_TO_OPENCLIP[model_name], output_dir_visual)
assert visual_path is not None, "Visual model export failed"
return visual_path, textual_path


def export_text_encoder(model: MultilingualCLIP, output_path: Path | str) -> None:
output_path = Path(output_path)
def forward(self: MultilingualCLIP, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
embs = self.transformer(input_ids, attention_mask)[0]
embs = (embs * attention_mask.unsqueeze(2)).sum(dim=1) / attention_mask.sum(dim=1)[:, None]
embs = self.LinearTransformation(embs)
return torch.nn.functional.normalize(embs, dim=-1)

    # unfortunately we need to monkeypatch forward for tracing to work here;
    # otherwise the export hits the 2 GiB protobuf serialization limit
MultilingualCLIP.forward = forward
    # trace with dummy input_ids and attention_mask of shape (1, 77)
    args = (torch.ones(1, 77, dtype=torch.int32), torch.ones(1, 77, dtype=torch.int32))
with warnings.catch_warnings():
warnings.simplefilter("ignore", UserWarning)
torch.onnx.export(
model,
args,
output_path.as_posix(),
input_names=["input_ids", "attention_mask"],
output_names=["embedding"],
opset_version=17,
# dynamic_axes={
# "input_ids": {0: "batch_size", 1: "sequence_length"},
# "attention_mask": {0: "batch_size", 1: "sequence_length"},
# },
)
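
A minimal usage sketch for this exporter, with invented output directories; the model name must be one of the _MCLIP_TO_OPENCLIP keys:

visual_path, textual_path = to_onnx(
    "M-CLIP/XLM-Roberta-Large-Vit-B-32",
    output_dir_visual="models/mclip/visual",
    output_dir_textual="models/mclip/textual",
)
# both returned paths point at a model.onnx inside the respective directory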


@@ -0,0 +1,114 @@
import os
import tempfile
import warnings
from dataclasses import dataclass, field
from pathlib import Path

import open_clip
import torch
from transformers import AutoTokenizer

from .util import get_model_path, save_config


@dataclass
class OpenCLIPModelConfig:
name: str
pretrained: str
image_size: int = field(init=False)
sequence_length: int = field(init=False)
def __post_init__(self) -> None:
open_clip_cfg = open_clip.get_model_config(self.name)
if open_clip_cfg is None:
raise ValueError(f"Unknown model {self.name}")
self.image_size = open_clip_cfg["vision_cfg"]["image_size"]
self.sequence_length = open_clip_cfg["text_cfg"].get("context_length", 77)


def to_onnx(
model_cfg: OpenCLIPModelConfig,
output_dir_visual: Path | str | None = None,
output_dir_textual: Path | str | None = None,
) -> tuple[Path | None, Path | None]:
visual_path = None
textual_path = None
with tempfile.TemporaryDirectory() as tmpdir:
model = open_clip.create_model(
model_cfg.name,
pretrained=model_cfg.pretrained,
jit=False,
cache_dir=os.environ.get("CACHE_DIR", tmpdir),
require_pretrained=True,
)
text_vision_cfg = open_clip.get_model_config(model_cfg.name)
model.eval()
for param in model.parameters():
param.requires_grad_(False)
if output_dir_visual is not None:
output_dir_visual = Path(output_dir_visual)
visual_path = get_model_path(output_dir_visual)
save_config(open_clip.get_model_preprocess_cfg(model), output_dir_visual / "preprocess_cfg.json")
save_config(text_vision_cfg, output_dir_visual.parent / "config.json")
export_image_encoder(model, model_cfg, visual_path)
if output_dir_textual is not None:
output_dir_textual = Path(output_dir_textual)
textual_path = get_model_path(output_dir_textual)
tokenizer_name = text_vision_cfg["text_cfg"].get("hf_tokenizer_name", "openai/clip-vit-base-patch32")
AutoTokenizer.from_pretrained(tokenizer_name).save_pretrained(output_dir_textual)
export_text_encoder(model, model_cfg, textual_path)
return visual_path, textual_path


def export_image_encoder(model: open_clip.CLIP, model_cfg: OpenCLIPModelConfig, output_path: Path | str) -> None:
output_path = Path(output_path)
def encode_image(image: torch.Tensor) -> torch.Tensor:
output = model.encode_image(image, normalize=True)
assert isinstance(output, torch.Tensor)
return output
args = (torch.randn(1, 3, model_cfg.image_size, model_cfg.image_size),)
traced = torch.jit.trace(encode_image, args) # type: ignore[no-untyped-call]
with warnings.catch_warnings():
warnings.simplefilter("ignore", UserWarning)
torch.onnx.export(
traced,
args,
output_path.as_posix(),
input_names=["image"],
output_names=["embedding"],
opset_version=17,
# dynamic_axes={"image": {0: "batch_size"}},
)


def export_text_encoder(model: open_clip.CLIP, model_cfg: OpenCLIPModelConfig, output_path: Path | str) -> None:
output_path = Path(output_path)
def encode_text(text: torch.Tensor) -> torch.Tensor:
output = model.encode_text(text, normalize=True)
assert isinstance(output, torch.Tensor)
return output
args = (torch.ones(1, model_cfg.sequence_length, dtype=torch.int32),)
traced = torch.jit.trace(encode_text, args) # type: ignore[no-untyped-call]
with warnings.catch_warnings():
warnings.simplefilter("ignore", UserWarning)
torch.onnx.export(
traced,
args,
output_path.as_posix(),
input_names=["text"],
output_names=["embedding"],
opset_version=17,
# dynamic_axes={"text": {0: "batch_size"}},
)
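
A rough usage sketch with illustrative directory names; either output directory may be None to skip that encoder's export:

visual_path, textual_path = to_onnx(
    OpenCLIPModelConfig("ViT-B-32", "openai"),
    output_dir_visual="models/openclip/visual",
    output_dir_textual="models/openclip/textual",
)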


@@ -0,0 +1,49 @@
from pathlib import Path

import onnx
import onnxruntime as ort
import onnxsim


def save_onnx(model: onnx.ModelProto, output_path: Path | str) -> None:
try:
onnx.save(model, output_path)
except ValueError as e:
if "The proto size is larger than the 2 GB limit." in str(e):
onnx.save(model, output_path, save_as_external_data=True, size_threshold=1_000_000)
else:
            raise


def optimize_onnxsim(model_path: Path | str, output_path: Path | str) -> None:
model_path = Path(model_path)
output_path = Path(output_path)
model = onnx.load(model_path.as_posix())
model, check = onnxsim.simplify(model)
assert check, "Simplified ONNX model could not be validated"
    # remove the previously saved model and any external-data shards
    # (Constant_* files, *.weight files) so stale tensors are not left behind
    for file in model_path.parent.iterdir():
        if file.name.startswith("Constant") or "onnx" in file.name or file.suffix == ".weight":
            file.unlink()
save_onnx(model, output_path)


def optimize_ort(
model_path: Path | str,
output_path: Path | str,
level: ort.GraphOptimizationLevel = ort.GraphOptimizationLevel.ORT_ENABLE_BASIC,
) -> None:
model_path = Path(model_path)
output_path = Path(output_path)
sess_options = ort.SessionOptions()
sess_options.graph_optimization_level = level
sess_options.optimized_model_filepath = output_path.as_posix()
    # constructing the session runs the optimizer and writes the result to optimized_model_filepath
    ort.InferenceSession(model_path.as_posix(), providers=["CPUExecutionProvider"], sess_options=sess_options)


def optimize(model_path: Path | str) -> None:
model_path = Path(model_path)
optimize_ort(model_path, model_path)
optimize_onnxsim(model_path, model_path)
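
A plausible end-to-end flow combining export and optimization, assuming the functions above are importable from one package (the file names are not shown in this diff, so imports are omitted):

visual_path, textual_path = to_onnx(OpenCLIPModelConfig("ViT-B-32", "openai"), "out/visual", "out/textual")
if visual_path is not None:
    optimize(visual_path)
if textual_path is not None:
    optimize(textual_path)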


@@ -0,0 +1,15 @@
import json
from pathlib import Path
from typing import Any


def get_model_path(output_dir: Path | str) -> Path:
output_dir = Path(output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
return output_dir / "model.onnx"


def save_config(config: Any, output_path: Path | str) -> None:
output_path = Path(output_path)
output_path.parent.mkdir(parents=True, exist_ok=True)
    with output_path.open("w") as f:
        json.dump(config, f)
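
For illustration, these helpers compose as follows (the paths are invented):

onnx_path = get_model_path("models/clip/textual")  # creates the directory, returns models/clip/textual/model.onnx
save_config({"image_size": 224}, "models/clip/config.json")  # serializes the dict as JSON, creating parent dirs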