export cli
This commit is contained in:
9
machine-learning/export/models/constants.py
Normal file
9
machine-learning/export/models/constants.py
Normal file
@@ -0,0 +1,9 @@
|
||||
from export.models.openclip import OpenCLIPModelConfig
|
||||
|
||||
|
||||
MCLIP_TO_OPENCLIP = {
|
||||
"XLM-Roberta-Large-Vit-B-32": OpenCLIPModelConfig("ViT-B-32", "openai"),
|
||||
"XLM-Roberta-Large-Vit-B-16Plus": OpenCLIPModelConfig("ViT-B-16-plus-240", "laion400m_e32"),
|
||||
"LABSE-Vit-L-14": OpenCLIPModelConfig("ViT-L-14", "openai"),
|
||||
"XLM-Roberta-Large-Vit-L-14": OpenCLIPModelConfig("ViT-L-14", "openai"),
|
||||
}
|
||||
@@ -1,22 +1,15 @@
|
||||
import tempfile
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
from export.models.constants import MCLIP_TO_OPENCLIP
|
||||
|
||||
import torch
|
||||
from multilingual_clip.pt_multilingual_clip import MultilingualCLIP
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from .openclip import OpenCLIPModelConfig
|
||||
from .openclip import to_onnx as openclip_to_onnx
|
||||
from .optimize import optimize
|
||||
from .util import get_model_path
|
||||
|
||||
_MCLIP_TO_OPENCLIP = {
|
||||
"M-CLIP/XLM-Roberta-Large-Vit-B-32": OpenCLIPModelConfig("ViT-B-32", "openai"),
|
||||
"M-CLIP/XLM-Roberta-Large-Vit-B-16Plus": OpenCLIPModelConfig("ViT-B-16-plus-240", "laion400m_e32"),
|
||||
"M-CLIP/LABSE-Vit-L-14": OpenCLIPModelConfig("ViT-L-14", "openai"),
|
||||
"M-CLIP/XLM-Roberta-Large-Vit-L-14": OpenCLIPModelConfig("ViT-L-14", "openai"),
|
||||
}
|
||||
from .util import get_model_path, clean_name
|
||||
|
||||
|
||||
def to_onnx(
|
||||
@@ -33,7 +26,7 @@ def to_onnx(
|
||||
param.requires_grad_(False)
|
||||
|
||||
export_text_encoder(model, textual_path)
|
||||
openclip_to_onnx(_MCLIP_TO_OPENCLIP[model_name], output_dir_visual)
|
||||
openclip_to_onnx(MCLIP_TO_OPENCLIP[clean_name(model_name)], output_dir_visual)
|
||||
optimize(textual_path)
|
||||
|
||||
|
||||
|
||||
@@ -3,6 +3,9 @@ from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
|
||||
_clean_name = str.maketrans(":\\/", "___", ".")
|
||||
|
||||
|
||||
def get_model_path(output_dir: Path | str) -> Path:
|
||||
output_dir = Path(output_dir)
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
@@ -13,3 +16,7 @@ def save_config(config: Any, output_path: Path | str) -> None:
|
||||
output_path = Path(output_path)
|
||||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
json.dump(config, output_path.open("w"))
|
||||
|
||||
|
||||
def clean_name(model_name: str) -> str:
|
||||
return model_name.split("/")[-1].translate(_clean_name)
|
||||
|
||||
Reference in New Issue
Block a user