How to build an NN model to be used in NXTfusion¶
As you can see from the examples in the examples/ folder, in order to perform inference over an ERgraph it is necessary to pass an NN object (a t.nn.Module) to NXTfusion.NXmultiRelSide.NNwrapper.fit.
Our original idea was to automatically build a model suitable for each specific NXTfusion.NXTfusion.ERgraph, but, while developing the library, we realized that this was not the best solution. First, NNs are designed to be customizable and flexible: why restrict users to our choices? Second, the entire idea of NXTfusion is to allow inference over totally arbitrary ER graphs: why restrict the most important part of the inference, namely the NN model that is actually trained to factorize the graph?
We thus opted for providing a skeleton class, NXTfusion.NXmodels.NXmodelProto, that contains a prototypical model that can be used in the NXTfusion.NXmultiRelSide.NNwrapper. It is barely more than an interface, but, together with this explanation and the NN models inheriting from it in the examples folder, we hope it is enough.
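For concreteness, here is an illustrative sketch of what such a skeleton may look like. This is not the actual NXTfusion.NXmodels.NXmodelProto (whose initialization scheme may differ); it only reproduces the usage pattern followed in the examples: subclass t.nn.Module, define layers in __init__, and call self.apply(self.init_weights).

```python
import torch as t

class ModelProtoSketch(t.nn.Module):
    """Hypothetical stand-in for NXmodelProto, for illustration only."""

    def init_weights(self, m):
        # Assumed initialization scheme: Xavier for Linear/Bilinear
        # weights, zeros for biases, small normal noise for Embeddings.
        if isinstance(m, (t.nn.Linear, t.nn.Bilinear)):
            t.nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                t.nn.init.zeros_(m.bias)
        elif isinstance(m, t.nn.Embedding):
            t.nn.init.normal_(m.weight, std=0.1)

    def forward(self, relName, i1, i2, s1=None, s2=None):
        # Subclasses must implement the relation-name dispatch themselves.
        raise NotImplementedError
```

Subclasses then only need to define their submodules and a forward that dispatches on the relation name, as shown in the examples below.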
NN model for single matrix factorization¶
In examples/example1.py we perform inference over an ERgraph with 1 Relation between 2 Entities (a matrix factorization problem).
To do so, we propose the following simple model.
class example1Model(NXmodelProto):
    def __init__(self, ERG, name):
        super(example1Model, self).__init__()
        self.name = name
        ##########DEFINE NN HERE##############
        protEmbLen = ERG["prot-drug"]["lenDomain1"]
        drugEmbLen = ERG["prot-drug"]["lenDomain2"]
        PROT_LATENT_SIZE = 10
        DRUG_LATENT_SIZE = 20
        ACTIVATION = t.nn.Tanh
        self.protEmb = t.nn.Embedding(protEmbLen, PROT_LATENT_SIZE)
        self.protHid = t.nn.Sequential(t.nn.Linear(PROT_LATENT_SIZE, 10), t.nn.LayerNorm(10), ACTIVATION())
        self.drugEmb = t.nn.Embedding(drugEmbLen, DRUG_LATENT_SIZE)
        self.drugHid = t.nn.Sequential(t.nn.Linear(DRUG_LATENT_SIZE, 20), t.nn.LayerNorm(20), ACTIVATION())
        self.biProtDrug = t.nn.Bilinear(10, 20, 10)
        self.outProtDrug = t.nn.Sequential(t.nn.LayerNorm(10), ACTIVATION(), t.nn.Dropout(0.1), t.nn.Linear(10, 1))
        self.apply(self.init_weights)
The trainable latent variables are represented by protEmb and drugEmb, which are t.nn.Embedding objects.
The embeddings are processed by the entity-specific protHid and drugHid hidden layers.
These layers are then joined (effectively performing the factorization) by the biProtDrug bilinear layer, which is followed by the final outProtDrug layer, which outputs the final prediction.
The names of these submodules are intended to be as familiar as possible with respect to the Entities and Relations initialized in the main of examples/example1.py.
The forward method helps in understanding how these submodules are arranged: they connect the protEmb and drugEmb latent variables (embeddings) to produce a non-linear final prediction of the cells of the target matrix.
def forward(self, relName, i1, i2, s1=None, s2=None):
    if relName == "prot-drug":
        u = self.protEmb(i1)
        v = self.drugEmb(i2)
        u = self.protHid(u).squeeze()
        v = self.drugHid(v).squeeze()
        o = self.biProtDrug(u, v)
        o = self.outProtDrug(o)
        return o
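The same data flow can be sketched standalone. The sizes below (50 proteins, 30 drugs) and the index values are made up for illustration; the .squeeze() calls are omitted because these toy index tensors are already 1D.

```python
import torch as t

# Made-up sizes: 50 proteins with 10 latent dims, 30 drugs with 20.
protEmb = t.nn.Embedding(50, 10)
drugEmb = t.nn.Embedding(30, 20)
protHid = t.nn.Sequential(t.nn.Linear(10, 10), t.nn.LayerNorm(10), t.nn.Tanh())
drugHid = t.nn.Sequential(t.nn.Linear(20, 20), t.nn.LayerNorm(20), t.nn.Tanh())
biProtDrug = t.nn.Bilinear(10, 20, 10)
outProtDrug = t.nn.Sequential(t.nn.LayerNorm(10), t.nn.Tanh(),
                              t.nn.Dropout(0.1), t.nn.Linear(10, 1))

i1 = t.tensor([0, 3, 7])   # protein indices, one per target matrix cell
i2 = t.tensor([1, 2, 5])   # drug indices, one per target matrix cell
u = protHid(protEmb(i1))   # (3, 10) protein representations
v = drugHid(drugEmb(i2))   # (3, 20) drug representations
o = outProtDrug(biProtDrug(u, v))
# o has shape (3, 1): one scalar prediction per (protein, drug) pair
```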
In order to make the parameters of the models (e.g. latent sizes) less dependent on magic numbers, the NXTfusion.NXmodels.NXmodelProto class takes as input the entire ERgraph, so every NXTfusion.NXTfusion.Relation and NXTfusion.NXTfusion.MetaRelation can be referred to by name in order to automatically fetch information such as the expected number of objects in each NXTfusion.NXTfusion.Entity, as shown here.
protEmbLen = ERG["prot-drug"]["lenDomain1"]
drugEmbLen = ERG["prot-drug"]["lenDomain2"]
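As a minimal illustration of this by-name lookup, here is a hypothetical plain dict standing in for a real ERgraph (which supports the same indexing); the sizes are made up.

```python
# Hypothetical mock of the ERgraph lookup; a real ERgraph is built from
# Entities and Relations, but is indexed by relation name the same way.
ERG = {"prot-drug": {"lenDomain1": 50, "lenDomain2": 30}}  # made-up sizes
protEmbLen = ERG["prot-drug"]["lenDomain1"]  # number of proteins (domain 1)
drugEmbLen = ERG["prot-drug"]["lenDomain2"]  # number of drugs (domain 2)
```

These values can then be passed directly to the t.nn.Embedding constructors, so no size is hard-coded in the model.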
An NN for tensor factorization¶
As shown in examples/example2.py, if the model needs to model multiple NXTfusion.NXTfusion.Relation between two NXTfusion.NXTfusion.Entity, once the submodules are defined for a single relation, it is sufficient to increase the number of output neurons in the outProtDrug final layer. In this case there are 3 relations to be reconstructed (predicted), and indeed there are 3 output neurons.
self.outProtDrug = t.nn.Sequential( t.nn.LayerNorm(10), ACTIVATION(), t.nn.Dropout(0.1), t.nn.Linear(10,3))
def forward(self, relName, i1, i2, s1=None, s2=None):
    if relName == "prot-drug":
        u = self.protEmb(i1)
        v = self.drugEmb(i2)
        u = self.protHid(u).squeeze()
        v = self.drugHid(v).squeeze()
        o = self.biProtDrug(u, v)
        o = self.outProtDrug(o)
        return o
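The effect of widening the output layer can be sketched in isolation. The sizes below are made up, and `joint` stands for the output of the bilinear layer for a batch of 4 (protein, drug) pairs.

```python
import torch as t

joint = t.randn(4, 10)  # stand-in for the bilinear layer output, batch of 4
# Same head as before, but with 3 output neurons: one per relation
# in the MetaRelation between the two entities.
head = t.nn.Sequential(t.nn.LayerNorm(10), t.nn.Tanh(),
                       t.nn.Dropout(0.1), t.nn.Linear(10, 3))
o = head(joint)
# o has shape (4, 3): o[:, k] reconstructs the k-th relation
# for each (protein, drug) pair in the batch.
```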
An NN for inference over arbitrary ER graphs¶
The most general case is when the NN model must be able to predict multiple NXTfusion.NXTfusion.MetaRelation involving multiple NXTfusion.NXTfusion.Entity (an arbitrarily connected ERgraph).
In examples/example3.py we show such an NN model. We define the embedding, the entity-specific hidden (hid) layers, and the bilinear+output layers for 2 NXTfusion.NXTfusion.MetaRelation among 3 NXTfusion.NXTfusion.Entity.
class example3Model(NXmodelProto):
    def __init__(self, ERG, name):
        super(example3Model, self).__init__()
        self.name = name
        ##########DEFINE NN HERE##############
        protEmbLen = ERG["prot-drug"]["lenDomain1"]
        drugEmbLen = ERG["prot-drug"]["lenDomain2"]
        domainEmbLen = ERG["prot-domain"]["lenDomain2"]
        PROT_LATENT_SIZE = 10
        DOMAIN_LATENT_SIZE = 10
        DRUG_LATENT_SIZE = 20
        ACTIVATION = t.nn.Tanh
        self.protEmb = t.nn.Embedding(protEmbLen, PROT_LATENT_SIZE)
        self.protHid = t.nn.Sequential(t.nn.Linear(PROT_LATENT_SIZE, 10), t.nn.LayerNorm(10), ACTIVATION())
        self.drugEmb = t.nn.Embedding(drugEmbLen, DRUG_LATENT_SIZE)
        self.drugHid = t.nn.Sequential(t.nn.Linear(DRUG_LATENT_SIZE, 20), t.nn.LayerNorm(20), ACTIVATION())
        self.biProtDrug = t.nn.Bilinear(10, 20, 10)
        self.outProtDrug = t.nn.Sequential(t.nn.LayerNorm(10), ACTIVATION(), t.nn.Dropout(0.1), t.nn.Linear(10, 1))
        self.domainEmb = t.nn.Embedding(domainEmbLen, DOMAIN_LATENT_SIZE)
        self.domainHid = t.nn.Sequential(t.nn.Linear(DOMAIN_LATENT_SIZE, 20), t.nn.LayerNorm(20), ACTIVATION())
        self.biProtDomain = t.nn.Bilinear(10, 20, 10)
        self.outProtDomain = t.nn.Sequential(t.nn.LayerNorm(10), ACTIVATION(), t.nn.Dropout(0.1), t.nn.Linear(10, 1))
        self.apply(self.init_weights)
Besides the initializations, the most important part to understand is the forward method. The NXTfusion.NXmultiRelSide.NNwrapper class will call the forward to predict each NXTfusion.NXTfusion.MetaRelation in the NXTfusion.NXTfusion.ERgraph, and to do so it will use the argument relName.
The NNwrapper thus uses the specific name of each NXTfusion.NXTfusion.MetaRelation to tell the forward which branch of the NN must be run (each branch corresponds to a NXTfusion.NXTfusion.MetaRelation, as explained in https://doi.org/10.1093/bioinformatics/btab09).
def forward(self, relName, i1, i2, s1=None, s2=None):
    if relName == "prot-drug":
        u = self.protEmb(i1)
        v = self.drugEmb(i2)
        u = self.protHid(u).squeeze()
        v = self.drugHid(v).squeeze()
        o = self.biProtDrug(u, v)
        o = self.outProtDrug(o)
    elif relName == "prot-domain":
        u = self.protEmb(i1)
        v = self.domainEmb(i2)
        u = self.protHid(u).squeeze()
        v = self.domainHid(v).squeeze()
        o = self.biProtDomain(u, v)
        o = self.outProtDomain(o)
    else:
        # Guard against unknown relation names, which would otherwise
        # leave o undefined.
        raise ValueError("Unknown MetaRelation name: %s" % relName)
    return o
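A minimal two-branch sketch (with made-up names and sizes, not the actual example3Model) shows the key property of this dispatch: the protein embedding is shared across branches, so training on one MetaRelation also updates the latent variables used by the other.

```python
import torch as t

# Made-up sizes: 50 proteins, 10 latent dims; each branch has its own head.
protEmb = t.nn.Embedding(50, 10)   # shared latent variables
drugHead = t.nn.Linear(10, 1)      # branch for "prot-drug"
domainHead = t.nn.Linear(10, 1)    # branch for "prot-domain"

def forward(relName, i1):
    u = protEmb(i1)  # shared by both branches
    if relName == "prot-drug":
        return drugHead(u)
    elif relName == "prot-domain":
        return domainHead(u)
    raise ValueError("Unknown MetaRelation name: %s" % relName)

i1 = t.tensor([0, 1, 2])
assert forward("prot-drug", i1).shape == (3, 1)
assert forward("prot-domain", i1).shape == (3, 1)
```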
It is thus crucial to build a forward method that specifies the different branch of the computation that each NXTfusion.NXTfusion.MetaRelation needs to run in order to obtain the final predictions.
Further reading¶
A more rigorous and theoretical description of the intuition behind the models shown in the examples/ scripts can be found in the original publication https://doi.org/10.1093/bioinformatics/btab09.