feat: preload textual model
This commit is contained in:
@@ -11,7 +11,7 @@ from typing import Any, AsyncGenerator, Callable, Iterator
|
||||
from zipfile import BadZipFile
|
||||
|
||||
import orjson
|
||||
from fastapi import Depends, FastAPI, File, Form, HTTPException
|
||||
from fastapi import Depends, FastAPI, File, Form, HTTPException, Response
|
||||
from fastapi.responses import ORJSONResponse
|
||||
from onnxruntime.capi.onnxruntime_pybind11_state import InvalidProtobuf, NoSuchFile
|
||||
from PIL.Image import Image
|
||||
@@ -28,6 +28,7 @@ from .schemas import (
|
||||
InferenceEntries,
|
||||
InferenceEntry,
|
||||
InferenceResponse,
|
||||
LoadModelEntry,
|
||||
MessageResponse,
|
||||
ModelFormat,
|
||||
ModelIdentity,
|
||||
@@ -124,6 +125,24 @@ def get_entries(entries: str = Form()) -> InferenceEntries:
|
||||
raise HTTPException(422, "Invalid request format.")
|
||||
|
||||
|
||||
def get_entry(entries: str = Form()) -> LoadModelEntry:
    """Parse the ``entries`` form field into a single model-load entry.

    The field is decoded as JSON and walked as a ``task -> type -> entry``
    mapping; the first entry found is converted and returned.

    Args:
        entries: JSON-encoded request taken from the submitted form.

    Returns:
        A dict with the keys ``name``, ``task``, ``type``, ``options`` and
        ``ttl`` built from the first entry in the request.

    Raises:
        HTTPException: 422 when the payload is not valid JSON, does not have
            the expected shape, or contains no entry at all.
    """
    try:
        request: PipelineRequest = orjson.loads(entries)
        for task, types in request.items():
            for model_type, entry in types.items():
                parsed: LoadModelEntry = {
                    "name": entry["modelName"],
                    "task": task,
                    "type": model_type,
                    "options": entry.get("options", {}),
                    # NOTE(review): load_model reads settings.model_ttl, not
                    # settings.ttl -- confirm this attribute exists on settings.
                    "ttl": entry["ttl"] if "ttl" in entry else settings.ttl,
                }
                # Only one model is loaded per request, so the first entry wins.
                return parsed
    except (orjson.JSONDecodeError, ValidationError, KeyError, AttributeError, TypeError) as e:
        # TypeError covers entries that are not mappings (e.g. a bare string).
        log.error(f"Invalid request format: {e}")
        raise HTTPException(422, "Invalid request format.") from e

    # Reaching this point means the request held no entries; without this
    # guard the function would silently return None to its caller.
    log.error("Invalid request format: no model entry found")
    raise HTTPException(422, "Invalid request format.")
|
||||
|
||||
|
||||
# Application instance; startup/shutdown handling is delegated to `lifespan`.
app = FastAPI(lifespan=lifespan)
|
||||
|
||||
|
||||
@@ -137,6 +156,13 @@ def ping() -> str:
|
||||
return "pong"
|
||||
|
||||
|
||||
@app.post("/load", response_model=TextResponse)
async def load_model(entry: LoadModelEntry = Depends(get_entry)) -> Response:
    """Fetch the requested model from the cache and load it ahead of use.

    Args:
        entry: Parsed load request supplied by the ``get_entry`` dependency.

    Returns:
        An empty ``Response`` with status code 200 once loading has finished.
    """
    # Honour the ttl resolved by get_entry (request value or its configured
    # fallback) rather than discarding it for a fixed setting.
    model = await model_cache.get(entry["name"], entry["type"], entry["task"], ttl=entry["ttl"])
    # The loaded model is not needed here; the call is made for its effect.
    await load(model)
    return Response(status_code=200)
|
||||
|
||||
|
||||
@app.post("/predict", dependencies=[Depends(update_state)])
|
||||
async def predict(
|
||||
entries: InferenceEntries = Depends(get_entries),
|
||||
|
||||
@@ -109,6 +109,17 @@ class InferenceEntry(TypedDict):
|
||||
options: dict[str, Any]
|
||||
|
||||
|
||||
class LoadModelEntry(InferenceEntry):
    """An ``InferenceEntry`` extended with a ``ttl`` key for load requests.

    This is a TypedDict: it describes the keys of a plain ``dict`` and cannot
    carry methods. Calling the class builds an ordinary dict without invoking
    any ``__init__``, so value checks (such as requiring a positive ``ttl``)
    must be performed by the code that constructs the dict.
    """

    # Time-to-live passed to the model cache when the model is requested.
    ttl: int
|
||||
|
||||
|
||||
# A pair of InferenceEntry lists.
# NOTE(review): the role of each list is not visible here -- confirm against get_entries.
InferenceEntries = tuple[list[InferenceEntry], list[InferenceEntry]]
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user