"""Semantic-Object-Context modality model."""
import warnings
from typing import Dict, Optional

import torch
import torch.nn.functional as F
import torch_tensorrt
from mlp_mixer_pytorch import MLPMixer
from torch import Tensor, nn
class SOCModel(nn.Module):
    """Semantic-Object-Context modality base model class.

    Stores the input geometry shared by the concrete SOC models.
    Subclasses must override :meth:`forward`.
    """

    def __init__(self, num_classes: int, num_objects: int, embeddings_size: int = 256) -> None:
        """Semantic-Object-Context modality model.

        Args:
            num_classes (int): number of classes
            num_objects (int): number of objects
            embeddings_size (int): size of output embeddings
        """
        super().__init__()
        # Input shape (batch_size, num_classes, num_objects, 3 (coords))
        self.num_classes = num_classes
        self.num_objects = num_objects
        # Stored (rather than discarded) so subclasses and callers can read
        # the configured output embedding size.
        self.embeddings_size = embeddings_size

    def forward(self, x: Tensor) -> Dict[str, Tensor]:
        """Forward pass; must be implemented by subclasses.

        Args:
            x (Tensor): input batch

        Returns:
            Dict[str, Tensor]: output dictionary

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError
class SOCMLP(SOCModel):
    """Semantic-Object-Context modality model built from a three-layer MLP."""

    def __init__(self, num_classes: int, num_objects: int, embeddings_size: int = 256) -> None:
        """Semantic-Object-Context modality MLP model.

        Args:
            num_classes (int): number of classes
            num_objects (int): number of objects
            embeddings_size (int): size of output embeddings
        """
        # The base class requires num_classes and num_objects (it stores them
        # on self); calling it with no arguments raises TypeError.
        super().__init__(num_classes, num_objects, embeddings_size)
        # Input shape (batch_size, num_classes, num_objects, 3 (coords)),
        # flattened to num_classes * num_objects * 3 features in forward().
        self.fc1 = nn.Linear(num_classes * num_objects * 3, 1024)
        self.fc2 = nn.Linear(1024, 512)
        self.fc3 = nn.Linear(512, embeddings_size)

    def forward(self, x: Tensor) -> Dict[str, Tensor]:
        """Forward pass.

        Args:
            x (Tensor): input batch whose non-batch dimensions hold
                num_classes * num_objects * 3 elements in total

        Returns:
            Dict[str, Tensor]: ``{"final_descriptor": descriptor}`` where
            descriptor has shape (batch_size, embeddings_size)
        """
        # flatten (unlike view) also accepts non-contiguous inputs and an
        # empty batch.
        x = torch.flatten(x, start_dim=1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        descriptor = self.fc3(x)
        out_dict: Dict[str, Tensor] = {"final_descriptor": descriptor}
        return out_dict
class SOCMLPMixer(SOCModel):
    """Semantic-Object-Context modality model based on MLP-Mixer.

    MLP-Mixer is an all-MLP architecture (an attention-free alternative to
    self-attention layers).
    Original paper: https://arxiv.org/abs/2105.01601
    Implementation: https://github.com/lucidrains/mlp-mixer-pytorch
    """

    def __init__(
        self,
        num_classes: int,
        num_objects: int,
        patch_size: int = 1,
        hidden_dim: int = 64,
        depth: int = 3,
        embeddings_size: int = 256,
    ) -> None:
        """Semantic-Object-Context modality model based on MLP-Mixer.

        The input is treated as an "image" with ``num_objects * 3`` channels
        and spatial size ``(num_classes, 1)``.

        Args:
            num_classes (int): number of classes
            num_objects (int): number of objects
            patch_size (int): patch size passed to MLPMixer. Because the image
                width is 1, this presumably has to be 1 to satisfy MLPMixer's
                divisibility check (NOTE(review): confirm against the
                installed mlp_mixer_pytorch version).
            hidden_dim (int): hidden dimension of the mixer
            depth (int): number of mixer layers
            embeddings_size (int): size of output embeddings
        """
        super().__init__(num_classes, num_objects, embeddings_size)
        self.mlp_mixer = MLPMixer(
            image_size=(num_classes, 1),
            channels=num_objects * 3,  # each object's 3 coords become channels
            patch_size=patch_size,
            dim=hidden_dim,
            depth=depth,
            # The mixer's output head is sized to the embedding size directly.
            num_classes=embeddings_size,
        )
        # Extra linear layer on top of the mixer output; it keeps the size at
        # embeddings_size (it does not project to a different size).
        self.fc = nn.Linear(embeddings_size, embeddings_size)

    def forward(self, x: Tensor) -> Dict[str, Tensor]:
        """Forward pass.

        Args:
            x (Tensor): input batch reshapeable to
                (batch_size, num_classes, num_objects * 3), e.g.
                (batch_size, num_classes, num_objects, 3)

        Returns:
            Dict[str, Tensor]: ``{"final_descriptor": descriptor}`` where
            descriptor has shape (batch_size, embeddings_size)
        """
        batch_size = x.shape[0]
        # (B, N, K, 3) -> (B, N, K*3); reshape also accepts non-contiguous input.
        x_reshaped = x.reshape(batch_size, self.num_classes, self.num_objects * 3)
        # -> (B, K*3, N, 1): K*3 channels over an N x 1 "image", matching
        # image_size=(num_classes, 1); the trailing axis is the width-1 axis.
        x_image = x_reshaped.permute(0, 2, 1).unsqueeze(-1)
        x_mixed = self.mlp_mixer(x_image)
        # flatten (unlike view(batch_size, -1)) also works for an empty batch.
        descriptor = self.fc(torch.flatten(x_mixed, start_dim=1))
        out_dict: Dict[str, Tensor] = {"final_descriptor": descriptor}
        return out_dict
class SOCMLPMixerModel(nn.Module):
    """Wrapper that runs a SOC model on ``batch["soc"]``, optionally via Torch-TensorRT.

    With ``forward_type="trt_fp32"`` the wrapped model is compiled with
    Torch-TensorRT (``ir="torch_compile"``) on the first forward call, and the
    compiled callable is reused afterwards. Any other ``forward_type`` runs the
    wrapped model eagerly.
    """

    # Workspace size for TensorRT, in bytes (20 GiB).
    _TRT_WORKSPACE_SIZE = 20 << 30
    # Minimum number of contiguous TensorRT-convertible operators needed to
    # form a TensorRT block (a lower value allows more graph segmentation).
    _TRT_MIN_BLOCK_SIZE = 7

    def __init__(self, model: nn.Module, forward_type: str = "fp32") -> None:
        """Wrap ``model``.

        Args:
            model: SOC model to wrap; it is called with ``batch["soc"]``.
            forward_type: ``"trt_fp32"`` to run through Torch-TensorRT;
                any other value runs ``model`` eagerly.
        """
        super().__init__()
        self.model = model
        self.forward_type = forward_type
        # Same exact-match test as forward(), so the warning is only emitted
        # when the TensorRT path will actually be used.
        if forward_type == "trt_fp32":
            warnings.warn(f"WARNING - {forward_type} mode is only for inference on cuda!")
        # Compiled model, built lazily by _compile_trt() on first use.
        self.trt_model = None

    def forward(self, batch):
        """Run the wrapped model on ``batch["soc"]``.

        Args:
            batch: mapping that holds the model input under the ``"soc"`` key.

        Returns:
            Whatever the wrapped (or compiled) model returns.
        """
        value = batch["soc"]
        if self.forward_type != "trt_fp32":
            return self.model(value)
        value = value.contiguous()
        if self.trt_model is None:
            self._compile_trt(value)
        return self.trt_model(value)

    def _compile_trt(self, example: Tensor) -> None:
        """Compile ``self.model`` with Torch-TensorRT and cache it in ``self.trt_model``."""
        trt_model = torch_tensorrt.compile(
            self.model,
            ir="torch_compile",
            inputs=[example],
            enabled_precisions={torch.float32},
            debug=False,  # no verbose TensorRT logs
            workspace_size=self._TRT_WORKSPACE_SIZE,
            min_block_size=self._TRT_MIN_BLOCK_SIZE,
            # Operations forced to run in Torch regardless of converter support.
            torch_executed_ops=set(),
        )
        # Bypass nn.Module.__setattr__: if the compiled result is an nn.Module
        # wrapping self.model (presumably the case for ir="torch_compile"),
        # normal assignment would register it as a submodule and duplicate the
        # wrapped parameters in state_dict() under "trt_model.*".
        object.__setattr__(self, "trt_model", trt_model)