How to build an NN model to be used in NXTfusion

As the scripts in the examples/ folder show, performing inference over an ERgraph requires passing an NN object (a t.nn.Module) to NXTfusion.NXmultiRelSide.NNwrapper.fit.

Our original idea was to automatically build a model suitable for each specific NXTfusion.NXTfusion.ERgraph, but while developing the library we realized that this was not the best solution. First, NNs are designed to be customizable and flexible: why restrict users to our choices? Second, the entire point of NXTfusion is to allow inference over completely arbitrary ER graphs: why constrain the most important part of the inference, namely the NN model that is actually trained to factorize the graph?

We thus opted to provide a skeleton class, NXTfusion.NXmodels.NXmodelProto, which contains a prototypical model that can be used with NXTfusion.NXmultiRelSide.NNwrapper. It is essentially an interface, but together with this explanation and the NN models inherited from it in the examples folder, we hope it is enough to get started.

NN model for single matrix factorization

In examples/example1.py we perform inference over an ERgraph with one Relation between two Entities (a matrix factorization problem).

To do so, we propose the following simple model.

class example1Model(NXmodelProto):
    def __init__(self, ERG, name):
        super(example1Model, self).__init__()
        self.name = name
        ########## DEFINE NN HERE ##########
        # Entity sizes fetched by Relation name from the ERgraph
        protEmbLen = ERG["prot-drug"]["lenDomain1"]
        drugEmbLen = ERG["prot-drug"]["lenDomain2"]
        PROT_LATENT_SIZE = 10
        DRUG_LATENT_SIZE = 20
        ACTIVATION = t.nn.Tanh

        # Trainable latent variables: one embedding vector per protein/drug
        self.protEmb = t.nn.Embedding(protEmbLen, PROT_LATENT_SIZE)
        self.protHid = t.nn.Sequential(t.nn.Linear(PROT_LATENT_SIZE, 10), t.nn.LayerNorm(10), ACTIVATION())

        self.drugEmb = t.nn.Embedding(drugEmbLen, DRUG_LATENT_SIZE)
        self.drugHid = t.nn.Sequential(t.nn.Linear(DRUG_LATENT_SIZE, 20), t.nn.LayerNorm(20), ACTIVATION())

        # Bilinear layer joining the two branches, followed by the output head
        self.biProtDrug = t.nn.Bilinear(10, 20, 10)
        self.outProtDrug = t.nn.Sequential(t.nn.LayerNorm(10), ACTIVATION(), t.nn.Dropout(0.1), t.nn.Linear(10, 1))
        self.apply(self.init_weights)

The trainable latent variables are protEmb and drugEmb, which are t.nn.Embedding objects. The embeddings are processed by the entity-specific protHid and drugHid hidden layers. These branches are then joined (effectively performing the factorization) by the biProtDrug bilinear layer, which is followed by the outProtDrug layer, which produces the final prediction.

The names of these submodules are intended to mirror the Entities and Relations initialized in the main of examples/example1.py.

The forward method shows how these submodules are arranged: they connect the protEmb and drugEmb latent variables (embeddings) into a non-linear final prediction of the cells of the target matrix.

def forward(self, relName, i1, i2, s1=None, s2=None):
    if relName == "prot-drug":
        u = self.protEmb(i1)
        v = self.drugEmb(i2)
        u = self.protHid(u).squeeze()
        v = self.drugHid(v).squeeze()
        o = self.biProtDrug(u, v)
        o = self.outProtDrug(o)
        return o
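As a minimal end-to-end sketch of this model (not the library's API: the ERgraph is replaced here by a plain dict with made-up entity counts, and NXmodelProto by a plain t.nn.Module, purely for illustration):

```python
import torch as t

# Hypothetical stand-in for the ERgraph: a dict supporting the same lookups.
ERG = {"prot-drug": {"lenDomain1": 50, "lenDomain2": 30}}

class Example1Sketch(t.nn.Module):
    def __init__(self, ERG):
        super().__init__()
        protEmbLen = ERG["prot-drug"]["lenDomain1"]
        drugEmbLen = ERG["prot-drug"]["lenDomain2"]
        # Trainable latent variables plus entity-specific hidden layers
        self.protEmb = t.nn.Embedding(protEmbLen, 10)
        self.protHid = t.nn.Sequential(t.nn.Linear(10, 10), t.nn.LayerNorm(10), t.nn.Tanh())
        self.drugEmb = t.nn.Embedding(drugEmbLen, 20)
        self.drugHid = t.nn.Sequential(t.nn.Linear(20, 20), t.nn.LayerNorm(20), t.nn.Tanh())
        # Bilinear join and output head (dropout omitted to keep the sketch deterministic)
        self.biProtDrug = t.nn.Bilinear(10, 20, 10)
        self.outProtDrug = t.nn.Sequential(t.nn.LayerNorm(10), t.nn.Tanh(), t.nn.Linear(10, 1))

    def forward(self, relName, i1, i2):
        if relName == "prot-drug":
            u = self.protHid(self.protEmb(i1))
            v = self.drugHid(self.drugEmb(i2))
            return self.outProtDrug(self.biProtDrug(u, v))

model = Example1Sketch(ERG)
i1 = t.LongTensor([0, 3, 7])  # protein indices
i2 = t.LongTensor([1, 2, 5])  # drug indices
out = model("prot-drug", i1, i2)
print(out.shape)  # one scalar prediction per (protein, drug) pair
```

Each row of out is the predicted value of one cell of the target matrix, addressed by the corresponding (i1, i2) index pair.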

Since the NXTfusion.NXmodels.NXmodelProto class takes the entire ERgraph as input, every NXTfusion.NXTfusion.Relation and NXTfusion.NXTfusion.MetaRelation can be accessed by name to automatically fetch information such as the expected number of objects in each NXTfusion.NXTfusion.Entity. This keeps the parameters of the models (e.g. latent sizes) free of magic numbers, as shown here.

protEmbLen = ERG["prot-drug"]["lenDomain1"]
drugEmbLen = ERG["prot-drug"]["lenDomain2"]
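For illustration only, here is a toy stand-in: a plain dict with made-up entity counts that supports the same name-based lookups as the ERgraph (the relation names and sizes below are hypothetical):

```python
# Hypothetical dict mimicking the name-based indexing of an ERgraph.
ERG = {
    "prot-drug":   {"lenDomain1": 50, "lenDomain2": 30},
    "prot-domain": {"lenDomain1": 50, "lenDomain2": 12},
}

# Embedding sizes are fetched by relation name, so no entity count is hard-coded:
protEmbLen = ERG["prot-drug"]["lenDomain1"]      # number of proteins
drugEmbLen = ERG["prot-drug"]["lenDomain2"]      # number of drugs
domainEmbLen = ERG["prot-domain"]["lenDomain2"]  # number of domains
print(protEmbLen, drugEmbLen, domainEmbLen)
```

If the underlying datasets change size, the model picks up the new entity counts automatically, with no edits to the model code.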

An NN for tensor factorization

As shown in examples/example2.py, if the model needs to model multiple NXTfusion.NXTfusion.Relation between two NXTfusion.NXTfusion.Entity, once the submodules are defined for a single relation it is sufficient to increase the number of output neurons in the outProtDrug final layer. In this case there are 3 relations to be reconstructed (predicted), and accordingly there are 3 output neurons.

self.outProtDrug = t.nn.Sequential(t.nn.LayerNorm(10), ACTIVATION(), t.nn.Dropout(0.1), t.nn.Linear(10, 3))

def forward(self, relName, i1, i2, s1=None, s2=None):
    if relName == "prot-drug":
        u = self.protEmb(i1)
        v = self.drugEmb(i2)
        u = self.protHid(u).squeeze()
        v = self.drugHid(v).squeeze()
        o = self.biProtDrug(u, v)
        o = self.outProtDrug(o)
        return o
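Only the final t.nn.Linear changes: the same bilinear code path now emits three predictions per (protein, drug) pair, one per relation. A quick shape check of that output head in isolation (a sketch with a made-up batch of 4 pairs, not the library's API):

```python
import torch as t

# Output head with 3 neurons, one per relation to be reconstructed.
outProtDrug = t.nn.Sequential(
    t.nn.LayerNorm(10), t.nn.Tanh(), t.nn.Dropout(0.1), t.nn.Linear(10, 3)
)

joined = t.randn(4, 10)  # pretend output of the bilinear layer for 4 pairs
out = outProtDrug(joined)
print(out.shape)  # 3 predictions per pair, one per relation
```

Column k of the output corresponds to relation k, so the three matrices sharing the same two entities are reconstructed by one shared branch.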

An NN for inference over arbitrary ER graphs

The NN model may also need to predict multiple NXTfusion.NXTfusion.MetaRelation involving multiple NXTfusion.NXTfusion.Entity (an arbitrarily connected ERgraph).

In examples/example3.py we show such an NN model: we define the embedding, the entity-specific hidden (hid) layers, and the bilinear+output layers for 2 NXTfusion.NXTfusion.MetaRelation among 3 NXTfusion.NXTfusion.Entity.

class example3Model(NXmodelProto):
    def __init__(self, ERG, name):
        super(example3Model, self).__init__()
        self.name = name
        ########## DEFINE NN HERE ##########
        protEmbLen = ERG["prot-drug"]["lenDomain1"]
        drugEmbLen = ERG["prot-drug"]["lenDomain2"]
        domainEmbLen = ERG["prot-domain"]["lenDomain2"]
        PROT_LATENT_SIZE = 10
        DOMAIN_LATENT_SIZE = 10
        DRUG_LATENT_SIZE = 20
        ACTIVATION = t.nn.Tanh

        self.protEmb = t.nn.Embedding(protEmbLen, PROT_LATENT_SIZE)
        self.protHid = t.nn.Sequential(t.nn.Linear(PROT_LATENT_SIZE, 10), t.nn.LayerNorm(10), ACTIVATION())

        self.drugEmb = t.nn.Embedding(drugEmbLen, DRUG_LATENT_SIZE)
        self.drugHid = t.nn.Sequential(t.nn.Linear(DRUG_LATENT_SIZE, 20), t.nn.LayerNorm(20), ACTIVATION())
        self.biProtDrug = t.nn.Bilinear(10, 20, 10)
        self.outProtDrug = t.nn.Sequential(t.nn.LayerNorm(10), ACTIVATION(), t.nn.Dropout(0.1), t.nn.Linear(10, 1))

        self.domainEmb = t.nn.Embedding(domainEmbLen, DOMAIN_LATENT_SIZE)
        self.domainHid = t.nn.Sequential(t.nn.Linear(DOMAIN_LATENT_SIZE, 20), t.nn.LayerNorm(20), ACTIVATION())
        self.biProtDomain = t.nn.Bilinear(10, 20, 10)
        self.outProtDomain = t.nn.Sequential(t.nn.LayerNorm(10), ACTIVATION(), t.nn.Dropout(0.1), t.nn.Linear(10, 1))

        self.apply(self.init_weights)

Besides the initializations, the most important part to understand is the forward method. The NXTfusion.NXmultiRelSide.NNwrapper class will call forward to predict each NXTfusion.NXTfusion.MetaRelation in the NXTfusion.NXTfusion.ERgraph, and to do so it uses the relName argument.

The NNwrapper thus uses the specific name of each NXTfusion.NXTfusion.MetaRelation to tell forward which branch of the NN must be run (each branch corresponds to a NXTfusion.NXTfusion.MetaRelation, as explained in https://doi.org/10.1093/bioinformatics/btab09).

def forward(self, relName, i1, i2, s1=None, s2=None):
    if relName == "prot-drug":
        u = self.protEmb(i1)
        v = self.drugEmb(i2)
        u = self.protHid(u).squeeze()
        v = self.drugHid(v).squeeze()
        o = self.biProtDrug(u, v)
        o = self.outProtDrug(o)
    elif relName == "prot-domain":
        u = self.protEmb(i1)
        v = self.domainEmb(i2)
        u = self.protHid(u).squeeze()
        v = self.domainHid(v).squeeze()
        o = self.biProtDomain(u, v)
        o = self.outProtDomain(o)
    return o

It is thus crucial that forward specifies a branch for each NXTfusion.NXTfusion.MetaRelation, so that every relation in the graph can be routed through the submodules that produce its predictions.
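Conceptually, this per-relation dispatch looks like the following pure-Python sketch (a stand-in forward with dummy return values, not the NNwrapper implementation; the relation names are the ones from example3):

```python
# Stand-in forward: each branch would run the submodules for that MetaRelation.
def forward(relName, i1, i2):
    if relName == "prot-drug":
        return f"prot-drug branch ({len(i1)} pairs)"
    elif relName == "prot-domain":
        return f"prot-domain branch ({len(i1)} pairs)"
    # A missing branch is a modelling error: fail loudly instead of
    # silently returning a stale or undefined value.
    raise ValueError(f"no branch defined for relation {relName!r}")

# The wrapper iterates over the MetaRelations in the ERgraph and calls
# forward by name, so every relation must have a matching branch.
for name in ["prot-drug", "prot-domain"]:
    print(forward(name, [0, 1], [2, 3]))
```

Raising on an unknown relName is a design choice for the sketch: it surfaces a forgotten branch at the first training step rather than producing confusing downstream errors.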

Further reading

A more rigorous and theoretical description of the intuition behind the models shown in the examples/ scripts can be found in the original publication: https://doi.org/10.1093/bioinformatics/btab09.