Source code for opr.modules.mlp

"""2-layer MLP module implementation."""
from functools import partial
from typing import Optional

from torch import Tensor, nn


class MLP(nn.Module):
    """Two-layer perceptron: ``fc1 -> activation -> dropout -> fc2 -> dropout``.

    MLP as used in Vision Transformer, MLP-Mixer and related networks.
    """

    def __init__(
        self,
        in_features: int,
        hidden_features: Optional[int] = None,
        out_features: Optional[int] = None,
        act_layer: str = "gelu",
        bias: bool = True,
        drop: float = 0.0,
        use_conv: bool = False,
    ) -> None:
        """Build both projection layers, the activation and the two dropouts.

        Args:
            in_features (int): Number of input features.
            hidden_features (int, optional): Width of the hidden projection.
                Falls back to ``in_features`` when omitted. Defaults to None.
            out_features (int, optional): Width of the output projection.
                Falls back to ``in_features`` when omitted. Defaults to None.
            act_layer (str): Name of the activation, matched case-insensitively
                against "gelu", "relu" and "leakyrelu". Defaults to "gelu".
            bias (bool): Whether both projections carry a bias term.
                Defaults to True.
            drop (float): Probability used by both dropout layers.
                Defaults to 0.0.
            use_conv (bool): Use ``nn.Conv2d`` with ``kernel_size=1`` for the
                projections instead of ``nn.Linear``. Defaults to False.

        Raises:
            ValueError: Unsupported activation function.
        """
        super().__init__()

        # ``or`` rather than ``is None``: any falsy width (None or 0) falls
        # back to the input width.
        out_dim = out_features or in_features
        hidden_dim = hidden_features or in_features

        # Both projections share one factory so they are always the same kind.
        if use_conv:
            make_projection = partial(nn.Conv2d, kernel_size=1)
        else:
            make_projection = nn.Linear

        activations = {
            "gelu": nn.GELU,
            "relu": nn.ReLU,
            "leakyrelu": nn.LeakyReLU,
        }

        # Assignment order below is the sub-module registration order
        # (fc1, act, drop1, fc2, drop2).
        self.fc1 = make_projection(in_features, hidden_dim, bias=bias)

        activation_cls = activations.get(act_layer.lower())
        if activation_cls is None:
            raise ValueError(f"Unsupported activation function: {act_layer}")
        self.act = activation_cls()

        self.drop1 = nn.Dropout(drop)
        self.fc2 = make_projection(hidden_dim, out_dim, bias=bias)
        self.drop2 = nn.Dropout(drop)

    def forward(self, x: Tensor) -> Tensor:
        """Run ``x`` through fc1, activation, dropout, fc2 and dropout, in that order."""
        hidden = self.drop1(self.act(self.fc1(x)))
        return self.drop2(self.fc2(hidden))