项目初次搭建
This commit is contained in:
90
yolov5/utils/triton.py
Normal file
90
yolov5/utils/triton.py
Normal file
@ -0,0 +1,90 @@
|
||||
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
|
||||
"""Utils to interact with the Triton Inference Server."""
|
||||
|
||||
import typing
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
class TritonRemoteModel:
    """
    A wrapper over a model served by the Triton Inference Server.

    It can be configured to communicate over GRPC or HTTP. It accepts Torch Tensors as input and returns them as
    outputs.
    """

    def __init__(self, url: str):
        """
        Connects to a Triton server and binds to the first model listed in its repository index.

        Args:
            url: Fully qualified address of the Triton server, e.g. grpc://localhost:8001 or http://localhost:8000.
                The "grpc" scheme selects the GRPC client; any other scheme falls back to the HTTP client.

        Raises:
            ValueError: If `url` has no host part (for instance when the scheme is missing).
        """
        parsed_url = urlparse(url)
        if not parsed_url.netloc:
            # Without a scheme, urlparse() puts "host:port" elsewhere and the client would get an empty address.
            raise ValueError(f"Invalid Triton URL '{url}', expected e.g. grpc://localhost:8001")

        if parsed_url.scheme == "grpc":
            from tritonclient.grpc import InferenceServerClient, InferInput

            self.client = InferenceServerClient(parsed_url.netloc)  # Triton GRPC client
            model_repository = self.client.get_model_repository_index()
            self.model_name = model_repository.models[0].name
            # as_json=True so that metadata can be indexed like a dict, same as in the HTTP branch
            self.metadata = self.client.get_model_metadata(self.model_name, as_json=True)
        else:
            from tritonclient.http import InferenceServerClient, InferInput

            self.client = InferenceServerClient(parsed_url.netloc)  # Triton HTTP client
            model_repository = self.client.get_model_repository_index()
            self.model_name = model_repository[0]["name"]
            self.metadata = self.client.get_model_metadata(self.model_name)

        def create_input_placeholders() -> typing.List[InferInput]:
            """Builds one empty InferInput per model input described in the metadata."""
            return [
                InferInput(i["name"], [int(s) for s in i["shape"]], i["datatype"]) for i in self.metadata["inputs"]
            ]

        # Stored as a closure because InferInput is protocol specific and only imported inside this method.
        self._create_input_placeholders_fn = create_input_placeholders

    @property
    def runtime(self):
        """Returns the model runtime: the metadata "backend" entry, else "platform", else None."""
        return self.metadata.get("backend", self.metadata.get("platform"))

    def __call__(self, *args, **kwargs) -> typing.Union[torch.Tensor, typing.List[torch.Tensor]]:
        """
        Invokes the model.

        Parameters can be provided via args or kwargs. args, if provided, are assumed to match the order of inputs of
        the model. kwargs are matched with the model input names.

        Returns:
            A single tensor when the model declares exactly one output, otherwise a list of tensors ordered like the
            outputs in the model metadata.
        """
        inputs = self._create_inputs(*args, **kwargs)
        response = self.client.infer(model_name=self.model_name, inputs=inputs)
        result = [torch.as_tensor(response.as_numpy(output["name"])) for output in self.metadata["outputs"]]
        return result[0] if len(result) == 1 else result

    @staticmethod
    def _placeholder_name(placeholder) -> str:
        """Returns the input name of a placeholder, whether `name` is exposed as a method or as an attribute."""
        name = placeholder.name
        # tritonclient's InferInput exposes name() as a method; using the bound method as a key never matches.
        return name() if callable(name) else name

    def _create_inputs(self, *args, **kwargs):
        """
        Creates input tensors from args or kwargs, not both; raises error if none or both are provided.

        Args:
            *args: Tensors in the same order as the model inputs.
            **kwargs: Tensors keyed by model input name.

        Returns:
            The list of input placeholders with their data set.

        Raises:
            RuntimeError: If no inputs are given, if both args and kwargs are given, or if the number of positional
                inputs differs from the number of model inputs.
            KeyError: If kwargs are used and a model input name is missing from them.
        """
        if not args and not kwargs:
            raise RuntimeError("No inputs provided.")
        if args and kwargs:
            raise RuntimeError("Cannot specify args and kwargs at the same time")

        placeholders = self._create_input_placeholders_fn()
        if args:
            if len(args) != len(placeholders):
                raise RuntimeError(f"Expected {len(placeholders)} inputs, got {len(args)}.")
            values = args
        else:
            values = [kwargs[self._placeholder_name(placeholder)] for placeholder in placeholders]

        for placeholder, value in zip(placeholders, values):
            # detach() keeps tensors that require grad convertible to numpy
            placeholder.set_data_from_numpy(value.detach().cpu().numpy())
        return placeholders
|
Reference in New Issue
Block a user