export cli

This commit is contained in:
mertalev
2023-11-12 18:01:12 -05:00
parent 069a32dcdb
commit ae80def7f2
5 changed files with 135 additions and 71 deletions

View File

@@ -0,0 +1,9 @@
from export.models.openclip import OpenCLIPModelConfig
MCLIP_TO_OPENCLIP = {
"XLM-Roberta-Large-Vit-B-32": OpenCLIPModelConfig("ViT-B-32", "openai"),
"XLM-Roberta-Large-Vit-B-16Plus": OpenCLIPModelConfig("ViT-B-16-plus-240", "laion400m_e32"),
"LABSE-Vit-L-14": OpenCLIPModelConfig("ViT-L-14", "openai"),
"XLM-Roberta-Large-Vit-L-14": OpenCLIPModelConfig("ViT-L-14", "openai"),
}

View File

@@ -1,22 +1,15 @@
import tempfile
import warnings
from pathlib import Path
from export.models.constants import MCLIP_TO_OPENCLIP
import torch
from multilingual_clip.pt_multilingual_clip import MultilingualCLIP
from transformers import AutoTokenizer
from .openclip import OpenCLIPModelConfig
from .openclip import to_onnx as openclip_to_onnx
from .optimize import optimize
from .util import get_model_path
_MCLIP_TO_OPENCLIP = {
"M-CLIP/XLM-Roberta-Large-Vit-B-32": OpenCLIPModelConfig("ViT-B-32", "openai"),
"M-CLIP/XLM-Roberta-Large-Vit-B-16Plus": OpenCLIPModelConfig("ViT-B-16-plus-240", "laion400m_e32"),
"M-CLIP/LABSE-Vit-L-14": OpenCLIPModelConfig("ViT-L-14", "openai"),
"M-CLIP/XLM-Roberta-Large-Vit-L-14": OpenCLIPModelConfig("ViT-L-14", "openai"),
}
from .util import get_model_path, clean_name
def to_onnx(
@@ -33,7 +26,7 @@ def to_onnx(
param.requires_grad_(False)
export_text_encoder(model, textual_path)
openclip_to_onnx(_MCLIP_TO_OPENCLIP[model_name], output_dir_visual)
openclip_to_onnx(MCLIP_TO_OPENCLIP[clean_name(model_name)], output_dir_visual)
optimize(textual_path)

View File

@@ -3,6 +3,9 @@ from pathlib import Path
from typing import Any
_clean_name = str.maketrans(":\\/", "___", ".")
def get_model_path(output_dir: Path | str) -> Path:
output_dir = Path(output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
@@ -13,3 +16,7 @@ def save_config(config: Any, output_path: Path | str) -> None:
output_path = Path(output_path)
output_path.parent.mkdir(parents=True, exist_ok=True)
json.dump(config, output_path.open("w"))
def clean_name(model_name: str) -> str:
return model_name.split("/")[-1].translate(_clean_name)