NXTfusion.NXmultiRelSide module

class NXTfusion.NXmultiRelSide.NNwrapper(model, dev, ignore_index, initialEpoch=0, nworkers=0)[source]

Bases: object

Class that wraps a t.nn.Module (PyTorch module) and exposes scikit-learn-like methods, such as .fit() and .predict(), to train and test it.

__init__(model, dev, ignore_index, initialEpoch=0, nworkers=0)[source]

Constructor for the NNwrapper class, which facilitates and standardizes the training of PyTorch neural networks.

Parameters
  • model (t.nn.Module) – The pytorch Neural Network that should be trained or tested.

  • dev (t.device) – The device on which the model should run, e.g. t.device(“cuda”) or t.device(“cpu:0”).

  • ignore_index (int) – Value used to mark missing (“N/A”) entries in partially observed matrices, so that the corresponding loss ignores those entries.
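As a rough illustration of the role of ignore_index, the sketch below computes a mean squared error that skips every label equal to the sentinel. It uses plain Python lists instead of tensors, and the function name masked_mse and the sentinel value -999 are hypothetical; they are not part of NXTfusion.

```python
def masked_mse(labels, predictions, ignore_index):
    """Mean squared error computed over the observed labels only."""
    # Keep only the pairs whose label is not the "missing value" sentinel.
    observed = [(y, yp) for y, yp in zip(labels, predictions) if y != ignore_index]
    if not observed:
        return 0.0  # nothing observed: contribute no loss
    return sum((y - yp) ** 2 for y, yp in observed) / len(observed)


IGNORE = -999  # hypothetical sentinel marking unobserved cells
labels = [1.0, IGNORE, 3.0, IGNORE]
predictions = [1.5, 0.2, 2.0, 7.0]

# Only positions 0 and 2 contribute to the loss.
loss = masked_mse(labels, predictions, IGNORE)
```

The predictions at the two masked positions can take any value without changing the result, which is the behaviour a partially observed matrix needs.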

computeLosses(y, yp, losses, relationData, weightRelations)[source]

Computes the losses for the entire ER graph by iterating over its relations. Used internally.

Parameters
  • y (t.tensor) – Pytorch tensor containing the labels

  • yp (t.tensor) – Pytorch tensor containing the predictions

  • losses (list) – list of losses (LossWrapper or t.nn.Module)

  • relationData (list) – list of MetaRelations

  • weightRelations (list) – list of weights associated with each loss

Returns

  • loss (real) – total loss

  • tmpLoss (list) – list containing the losses associated with each Relation

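The aggregation described above can be sketched as a weighted sum of per-relation losses. This is a minimal sketch under stated assumptions, not the library's implementation: it works on plain lists, assumes one label list and one prediction list per relation, and returns the unweighted per-relation losses; compute_losses, mae and mse are hypothetical names.

```python
def mae(y, yp):
    """Mean absolute error between two equally long lists."""
    return sum(abs(a - b) for a, b in zip(y, yp)) / len(y)


def mse(y, yp):
    """Mean squared error between two equally long lists."""
    return sum((a - b) ** 2 for a, b in zip(y, yp)) / len(y)


def compute_losses(y, yp, losses, weights):
    """Return (total, per_relation), with total = sum(weight * loss)."""
    # One loss function, one label list and one prediction list per relation.
    per_relation = [fn(y_r, yp_r) for fn, y_r, yp_r in zip(losses, y, yp)]
    total = sum(w * l for w, l in zip(weights, per_relation))
    return total, per_relation


y = [[1.0, 2.0], [0.0, 1.0]]   # labels for two relations
yp = [[1.5, 2.5], [0.0, 3.0]]  # matching predictions
total, per_relation = compute_losses(y, yp, [mae, mse], [1.0, 0.5])
```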

countParams(parameters: list) → int[source]

Method that counts the number of trainable parameters in the model.

Parameters

parameters (iterable) – The iterable containing the PyTorch model parameters.

Returns

Number of trainable parameters

Return type

int
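A common way to count trainable parameters in PyTorch is to sum numel() over the parameters that require gradients. The sketch below assumes that this is the intended notion of "trainable"; count_trainable_params is a hypothetical stand-alone function, not the method itself.

```python
import torch as t


def count_trainable_params(parameters):
    """Sum the element counts of all parameters that require gradients."""
    return sum(p.numel() for p in parameters if p.requires_grad)


model = t.nn.Linear(10, 3)  # 10 * 3 weights + 3 biases
n_all = count_trainable_params(model.parameters())

# Freezing the bias removes its 3 elements from the count.
model.bias.requires_grad_(False)
n_frozen_bias = count_trainable_params(model.parameters())
```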

fit(relationList, epochs=100, batch_size=500, save_model_every=10, LOG=False, MUTE=True)[source]

Trains the wrapped PyTorch model. It is analogous to the scikit-learn .fit() method.

Parameters
  • relationList (ERgraph) –

  • epochs (int) – Number of epochs

  • batch_size (int) – batch size during training

  • save_model_every (int) – Saves the model every save_model_every epochs.
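To make the interaction between epochs and save_model_every concrete, the sketch below lists the epochs at which a periodic save would fire. The exact rule is an assumption: the source does not state whether epochs are counted from 0 or 1, nor how initialEpoch shifts the schedule. checkpoint_epochs is a hypothetical helper.

```python
def checkpoint_epochs(epochs, save_model_every, initial_epoch=0):
    """Epoch numbers at which a periodic save would fire (assumed rule)."""
    # Assumption: epochs are numbered from initial_epoch + 1 and the model
    # is saved whenever the epoch number is a multiple of save_model_every.
    last = initial_epoch + epochs
    return [e for e in range(initial_epoch + 1, last + 1)
            if e % save_model_every == 0]


# With the default arguments of fit(), this yields ten checkpoints.
schedule = checkpoint_epochs(epochs=100, save_model_every=10)
```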

predict(ERgraph, X, metaRelationName, relationName, sidex1=None, sidex2=None, batch_size=500, plotGraph=False)[source]

Function that computes predictions with the wrapped PyTorch model. It is analogous to the scikit-learn .predict() method.

Parameters
  • ERgraph (ERgraph) –

  • X (list) – List containing the 2D coordinates of the positions that should be predicted in the ERgraph.metaRelationName.relationName Relation.

  • metaRelationName (str) – Name of the MetaRelation that contains the target relation

  • relationName (str) – Name of the relation that you want to predict

  • batch_size (int) – batch size during prediction

Returns

yp – List containing the predictions for the target Relation

Return type

list
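The role of batch_size during prediction can be illustrated by splitting the coordinate list X into consecutive chunks. This is only a sketch: representing each position as a (row, column) tuple is an assumption about the layout of X, and batched is a hypothetical helper, not part of NXTfusion.

```python
def batched(coords, batch_size):
    """Yield consecutive slices of coords with at most batch_size items."""
    for start in range(0, len(coords), batch_size):
        yield coords[start:start + batch_size]


# Every cell of a 3 x 4 relation matrix, as (row, column) pairs.
X = [(i, j) for i in range(3) for j in range(4)]
batches = list(batched(X, batch_size=5))
batch_sizes = [len(b) for b in batches]
```

The last batch is allowed to be shorter than batch_size, so no position is dropped and the predictions can be concatenated back in the original order.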

printBatchesLog(rel, e, bi, errTotOld, errTot, totLen, epochTime, loadTime, forwTime, LossTime, start, batch_size, mute=True)[source]

This method simplifies the live logging of the batches. If muted, it only reports excessively long loading times.

Parameters

TODO


processDatasets(DS: list)[source]

This method takes the external Entity Relation graph representation, one MetaRelation at a time, and converts it into the lower-level data structures used within the wrapper, creating a MetaDataset structure from the ER representation passed as input. This structure mirrors the ERgraph but is suited to efficient multi-task mini-batching during training. The method is used internally by NNwrapper and does not need to be called by the user.

Parameters

DS (MetaRelation) –

Returns

  • DS (MetaRelation) – The original MetaRelation stripped of its data matrices to save space (this has not been benchmarked yet).

  • datasets (list of SubDataset)

  • losses (list of losses)

  • refSize (size of the target matrix (to be removed))

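One plausible way to turn a partially observed matrix into a structure suited to mini-batching is to flatten it into (row, column, value) samples and drop the missing cells. The sketch below shows only that general idea; it is not how SubDataset or MetaDataset are actually built, and matrix_to_samples and the sentinel -1 are hypothetical.

```python
def matrix_to_samples(matrix, ignore_index):
    """Flatten a 2D list into (row, col, value) triples, skipping missing cells."""
    return [(i, j, v)
            for i, row in enumerate(matrix)
            for j, v in enumerate(row)
            if v != ignore_index]


IGNORE = -1  # hypothetical sentinel for unobserved cells
relation = [[0.3, IGNORE],
            [IGNORE, 0.9]]

# Only the two observed cells survive, each with its coordinates.
samples = matrix_to_samples(relation, IGNORE)
```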

saveModel(e: int)[source]

Method that stores the trained model at a certain iteration. Used internally.

Parameters

e (int) – Epoch number. The model is saved automatically with the t.save function, under a name built from the model name and the epoch number.