Modular Training and Evaluation of Neural Networks
Copyright (c) 2022, Benjamin Kaminow
Building models should be done using the mtenn.config API.
A small example for a SchNet model is shown below, but more details for SchNet and other models can be found in the respective class definitions.
We will construct a SchNet model with default parameters and a delta G strategy for combining our complex, protein, and ligand representations. We will leave our predictions in the returned implicit kT units (i.e., no Readout block).
from mtenn.config import SchNetModelConfig
# Create the config using all default parameters (which includes the delta G strategy)
model_config = SchNetModelConfig()
# Build the actual PyTorch model from the config
model = model_config.build()
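The delta G strategy described above combines the model's outputs for the complex, the protein alone, and the ligand alone as dG = G(complex) - G(protein) - G(ligand). The sketch below illustrates only that arithmetic in plain PyTorch and is not mtenn's implementation. ToyEnergyModel and delta_g are hypothetical names, the toy model stands in for a real model such as SchNet, and the split of the complex into protein and ligand atoms via a boolean ligand mask is an assumption made for illustration.

```python
import torch


class ToyEnergyModel(torch.nn.Module):
    """Hypothetical stand-in for an energy model such as SchNet (not part of mtenn)."""

    def __init__(self, n_elements: int = 100, hidden: int = 16):
        super().__init__()
        # Learned per-atom embedding, looked up by atomic number
        self.embed = torch.nn.Embedding(n_elements, hidden)
        # Small MLP mapping the pooled representation to one scalar energy
        self.head = torch.nn.Sequential(
            torch.nn.Linear(hidden, hidden),
            torch.nn.Tanh(),
            torch.nn.Linear(hidden, 1),
        )

    def forward(self, pose: dict) -> torch.Tensor:
        # Sum-pool per-atom embeddings; this toy ignores pose["pos"]
        rep = self.embed(pose["z"]).sum(dim=0)
        return self.head(rep).squeeze(-1)


def delta_g(model: torch.nn.Module, pose: dict, lig_mask: torch.Tensor) -> torch.Tensor:
    """Hypothetical delta strategy: G(complex) - G(protein) - G(ligand)."""
    # Split every per-atom tensor into protein (~mask) and ligand (mask) subsets
    prot = {k: v[~lig_mask] for k, v in pose.items()}
    lig = {k: v[lig_mask] for k, v in pose.items()}
    return model(pose) - model(prot) - model(lig)


torch.manual_seed(0)
pose = {"z": torch.randint(low=1, high=17, size=(100,)), "pos": torch.rand((100, 3))}
# Assume, for illustration only, that the last 20 atoms are the ligand
lig_mask = torch.zeros(100, dtype=torch.bool)
lig_mask[80:] = True
toy = ToyEnergyModel()
dg = delta_g(toy, pose, lig_mask)
```

If the energy were a plain sum of per-atom terms, the three terms would cancel exactly; the nonlinear head in this toy is what makes the difference non-zero.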
The input passed to this model should be a dict with the following keys (based on the underlying model):
SchNet
  z: Tensor of atomic number for each atom, shape (n,)
  pos: Tensor of coordinates for each atom, shape (n, 3)
E3NN
  x: Tensor of one-hot encodings of element for each atom, shape (n, one_hot_length)
  pos: Tensor of coordinates for each atom, shape (n, 3)
  z: Tensor of bool labels of whether each atom is a protein atom (False) or ligand atom (True), shape (n,)
GAT
  g: DGL graph object
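The sketch below shows one way to assemble the SchNet and E3NN input dicts for the same set of atoms using PyTorch. The helper name make_inputs, the one-hot length of 100, and indexing the one-hot encoding by atomic number are illustrative assumptions, not values prescribed by mtenn; GAT is omitted because it needs a DGL graph.

```python
import torch
import torch.nn.functional as F


def make_inputs(z: torch.Tensor, pos: torch.Tensor, is_ligand: torch.Tensor, one_hot_length: int = 100):
    """Build hypothetical SchNet and E3NN input dicts for the same n atoms.

    z: atomic numbers, shape (n,); pos: coordinates, shape (n, 3);
    is_ligand: bool mask, True for ligand atoms, shape (n,).
    """
    schnet_input = {"z": z, "pos": pos}
    e3nn_input = {
        # One-hot element encoding indexed by atomic number (an assumed scheme)
        "x": F.one_hot(z, num_classes=one_hot_length).float(),
        "pos": pos,
        # For E3NN, "z" is the protein (False) / ligand (True) label, not the atomic number
        "z": is_ligand,
    }
    return schnet_input, e3nn_input


n = 100
z = torch.randint(low=1, high=17, size=(n,))
pos = torch.rand((n, 3))
is_ligand = torch.zeros(n, dtype=torch.bool)
is_ligand[80:] = True  # assume the last 20 atoms belong to the ligand
schnet_input, e3nn_input = make_inputs(z, pos, is_ligand)
```

Note that the key z carries different information in the two dicts, so inputs built for one model cannot be passed unchanged to the other.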
The prediction can then be generated simply with:
import torch
# Using random data just for demonstration purposes
pose = {"z": torch.randint(low=1, high=17, size=(100,)), "pos": torch.rand((100, 3))}
pred = model(pose)
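Because no Readout block is used, pred is in implicit kT units. A minimal sketch for converting such a value to kcal/mol multiplies it by kT = R * T. The 298.15 K default temperature and the function name kt_to_kcal_per_mol are assumptions for illustration; this is not an mtenn function.

```python
# Molar gas constant in kcal/(mol*K): 8.314462618 J/(mol*K) / 4184 J/kcal
R_KCAL = 8.314462618 / 4184.0


def kt_to_kcal_per_mol(value_kt, temperature_k: float = 298.15):
    """Hypothetical helper: convert an energy in units of kT to kcal/mol.

    Works on plain floats or on torch tensors such as pred.
    """
    return value_kt * R_KCAL * temperature_k


# At 298.15 K, 1 kT is roughly 0.5925 kcal/mol
one_kt = kt_to_kcal_per_mol(1.0)
```

Applied to the model output, this would be kt_to_kcal_per_mol(pred).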
mtenn is now on conda-forge! To install, simply run:
mamba install -c conda-forge mtenn
Project based on the Computational Molecular Science Python Cookiecutter version 1.6.