NXTfusion.NXmultiRelSide module
-
class NXTfusion.NXmultiRelSide.NNwrapper(model, dev, ignore_index, initialEpoch=0, nworkers=0)
Bases: object
Class that wraps a t.nn.Module (pytorch module) and exposes scikit-learn-like methods, such as .fit() and .predict(), to train and test it.
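A minimal, framework-free sketch can illustrate the wrapper pattern described above. Everything in it is hypothetical. `ToyLinearModel` stands in for a t.nn.Module, and `MiniWrapper` only mimics the .fit()/.predict() interface, using plain full-batch gradient descent on a mean squared error. It is not NXTfusion's implementation: the real fit() takes an ERgraph rather than X, y lists.

```python
from typing import List, Sequence


class ToyLinearModel:
    """Hypothetical stand-in for a t.nn.Module: y = w * x + b."""

    def __init__(self) -> None:
        self.w = 0.0
        self.b = 0.0

    def forward(self, x: float) -> float:
        return self.w * x + self.b


class MiniWrapper:
    """Hypothetical scikit-learn-style wrapper (not NXTfusion's code)."""

    def __init__(self, model: ToyLinearModel, lr: float = 0.05) -> None:
        self.model = model
        self.lr = lr

    def fit(self, X: Sequence[float], y: Sequence[float], epochs: int = 500) -> "MiniWrapper":
        n = len(X)
        for _ in range(epochs):
            # Full-batch gradients of the mean squared error w.r.t. w and b.
            grad_w = grad_b = 0.0
            for xi, yi in zip(X, y):
                err = self.model.forward(xi) - yi
                grad_w += 2 * err * xi / n
                grad_b += 2 * err / n
            self.model.w -= self.lr * grad_w
            self.model.b -= self.lr * grad_b
        return self  # returning self mirrors the scikit-learn convention

    def predict(self, X: Sequence[float]) -> List[float]:
        return [self.model.forward(xi) for xi in X]
```

For example, `MiniWrapper(ToyLinearModel()).fit([0.0, 1.0, 2.0, 3.0], [1.0, 3.0, 5.0, 7.0]).predict([4.0])` returns a value very close to 9.0, because the data follow y = 2x + 1.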
-
__init__(model, dev, ignore_index, initialEpoch=0, nworkers=0)
Constructor for the NNwrapper class, which facilitates and standardizes the training of pytorch neural networks.
- Parameters
model (t.nn.Module) – The pytorch Neural Network that should be trained or tested.
dev (t.device) – The device on which the model should run, e.g. t.device(“cuda”) or t.device(“cpu:0”).
ignore_index (int) – The ignore index value that will be used to mark “missing values” and “N/A” on partially observed matrices, in order to let the corresponding loss ignore those instances.
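The following sketch shows how entries marked with ignore_index can be masked out of a loss. It assumes a squared-error loss and uses plain nested lists in place of tensors. The name `masked_mse` is hypothetical. In pytorch, some losses, such as t.nn.CrossEntropyLoss and t.nn.NLLLoss, accept an ignore_index argument directly. Regression losses like t.nn.MSELoss do not, so an explicit mask like this one is a common workaround.

```python
from typing import List


def masked_mse(y: List[List[float]], yp: List[List[float]], ignore_index: float) -> float:
    """Mean squared error over observed entries only (hypothetical helper).

    Entries of y equal to ignore_index are treated as missing and skipped,
    so they contribute nothing to the loss.
    """
    total = 0.0
    count = 0
    for row_y, row_p in zip(y, yp):
        for target, pred in zip(row_y, row_p):
            if target == ignore_index:
                continue  # missing / N/A value: ignored by the loss
            total += (pred - target) ** 2
            count += 1
    # Avoid dividing by zero when every entry is missing.
    return total / count if count else 0.0
```

With `y = [[1.0, -1], [-1, 0.0]]` and `ignore_index=-1`, only the two observed entries are scored. Arbitrary predictions at the masked positions therefore do not affect the result.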
-
computeLosses(y, yp, losses, relationData, weightRelations)
This function computes the losses for the entire ER graph by iterating through its relations. Used internally.
- Parameters
y (t.Tensor) – Pytorch tensor containing the labels
yp (t.Tensor) – Pytorch tensor containing the predictions
losses (list) – list of losses (LossWrapper or t.nn.Module)
relationData (list) – list of MetaRelations
weightRelations (list) – list of weights associated to each loss
- Returns
loss (real) – total loss
tmpLoss (list) – list containing the losses associated to each Relation
meta private:
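The sketch below shows one way per-relation losses could be combined into the two documented return values. It assumes that the total is the weighted sum of the per-relation losses and that tmpLoss stores the unweighted values; the source states neither detail. `compute_losses`, `mse` and `mae` are hypothetical names. relationData is omitted because its role in the computation is not documented.

```python
from typing import Callable, List, Sequence, Tuple

LossFn = Callable[[Sequence[float], Sequence[float]], float]


def mse(y: Sequence[float], p: Sequence[float]) -> float:
    return sum((a - b) ** 2 for a, b in zip(y, p)) / len(y)


def mae(y: Sequence[float], p: Sequence[float]) -> float:
    return sum(abs(a - b) for a, b in zip(y, p)) / len(y)


def compute_losses(ys: Sequence[Sequence[float]],
                   yps: Sequence[Sequence[float]],
                   losses: Sequence[LossFn],
                   weights: Sequence[float]) -> Tuple[float, List[float]]:
    """Return (weighted total loss, list of per-relation losses)."""
    if not (len(ys) == len(yps) == len(losses) == len(weights)):
        raise ValueError("one label set, prediction set, loss and weight per relation")
    tmp_loss: List[float] = []
    total = 0.0
    for y, yp, loss_fn, w in zip(ys, yps, losses, weights):
        value = loss_fn(y, yp)
        tmp_loss.append(value)  # assumption: unweighted per-relation value
        total += w * value      # assumption: weights scale each relation's loss
    return total, tmp_loss
```

Summing weighted terms lets a single backward pass on the total update all relations at once. This is the usual motivation for multi-task loss weighting.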
-
countParams(parameters: list) → int
Method that counts the number of trainable parameters in the model.
- Parameters
parameters (iterable) – The iterable containing the pytorch model parameters.
- Returns
Number of parameters
- Return type
int
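In pytorch, counting trainable parameters is commonly written as `sum(p.numel() for p in model.parameters() if p.requires_grad)`. The sketch below uses the same expression but substitutes a hypothetical `FakeParam` class for real tensors, so it runs without pytorch. Whether countParams filters on requires_grad exactly this way is an assumption.

```python
from dataclasses import dataclass
from math import prod
from typing import Iterable, Tuple


@dataclass
class FakeParam:
    """Hypothetical stand-in for a pytorch Parameter (shape + requires_grad)."""
    shape: Tuple[int, ...]
    requires_grad: bool = True

    def numel(self) -> int:
        # Number of scalar elements; prod(()) == 1 for a 0-dim tensor.
        return prod(self.shape)


def count_params(parameters: Iterable[FakeParam]) -> int:
    # Only parameters that receive gradients count as trainable.
    return sum(p.numel() for p in parameters if p.requires_grad)
```

For instance, the weight (5×10) and bias (5) of a linear layer give 50 + 5 = 55 trainable parameters, and a frozen 3×3 matrix adds nothing.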
-
fit(relationList, epochs=100, batch_size=500, save_model_every=10, LOG=False, MUTE=True)
Function that trains the wrapped pytorch model. It is analogous to the scikit-learn .fit() method.
- Parameters
relationList (ERgraph) – The ERgraph containing the relations used for training.
epochs (int) – Number of epochs
batch_size (int) – batch size during training
save_model_every (int) – The model is saved every save_model_every epochs
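The interplay of epochs, batch_size and save_model_every can be sketched with two small helpers. The sketch assumes that mini-batches are contiguous index ranges and that a checkpoint is written whenever the 1-based epoch number is a multiple of save_model_every. Both helper names are hypothetical; the actual fit() may shuffle samples or count epochs differently.

```python
from typing import Iterator, List, Tuple


def minibatch_ranges(n: int, batch_size: int) -> Iterator[Tuple[int, int]]:
    """Yield (start, end) index pairs covering n samples in chunks of batch_size."""
    for start in range(0, n, batch_size):
        yield start, min(start + batch_size, n)  # the last batch may be smaller


def checkpoint_epochs(epochs: int, save_model_every: int) -> List[int]:
    """1-based epochs after which the model would be saved (assumed schedule)."""
    return [e for e in range(1, epochs + 1) if e % save_model_every == 0]
```

For example, 1200 samples with batch_size=500 give three batches per epoch, the last one holding 200 samples. The defaults epochs=100 and save_model_every=10 give ten checkpoints.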
-
predict(ERgraph, X, metaRelationName, relationName, sidex1=None, sidex2=None, batch_size=500, plotGraph=False)
Function that uses the wrapped pytorch model to compute predictions. It is analogous to the scikit-learn .predict() method.
- Parameters
ERgraph (ERgraph) – The ERgraph containing the target relation.
X (list) – List containing the 2D coordinates of the positions that should be predicted in the ERgraph.metaRelationName.relationName Relation.
metaRelationName (str) – Name of the MetaRelation that contains the target relation
relationName (str) – Name of the relation that you want to predict
batch_size (int) – batch size during prediction
- Returns
yp – List containing the predictions for the target Relation
- Return type
list
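Batched prediction over the coordinate list X can be sketched as shown here. The sketch assumes the model scores a batch of (row, column) pairs at a time. `predict_in_batches` and the `score_batch` callable are hypothetical stand-ins for the wrapped model's forward pass. The sketch ignores sidex1, sidex2 and plotGraph.

```python
from typing import Callable, List, Sequence, Tuple

Coord = Tuple[int, int]


def predict_in_batches(score_batch: Callable[[Sequence[Coord]], List[float]],
                       X: Sequence[Coord],
                       batch_size: int = 500) -> List[float]:
    """Score the coordinates in X in chunks of batch_size, preserving order."""
    yp: List[float] = []
    for start in range(0, len(X), batch_size):
        batch = X[start:start + batch_size]
        yp.extend(score_batch(batch))  # one forward pass per chunk
    return yp
```

Chunking bounds peak memory to one batch of inputs, which is why a batch_size parameter is useful even at prediction time.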
-
printBatchesLog(rel, e, bi, errTotOld, errTot, totLen, epochTime, loadTime, forwTime, LossTime, start, batch_size, mute=True)
This method simplifies live logging of the batches. If muted, it only reports excessively long loading times.
- Parameters
TODO –
- Returns
- Return type
meta private:
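The muted/unmuted behaviour can be sketched as a function that returns either a log line or None. The 1-second `slow_load_threshold`, the function name and the message formats are assumptions made for illustration. The source does not say what counts as an excessively long loading time.

```python
from typing import Optional


def batch_log_line(epoch: int, batch_idx: int, err: float, load_time: float,
                   mute: bool = True,
                   slow_load_threshold: float = 1.0) -> Optional[str]:
    """Return a log message for one batch, or None if nothing should be printed."""
    if not mute:
        # Unmuted: report every batch.
        return f"epoch {epoch} batch {batch_idx}: loss={err:.4f} load={load_time:.2f}s"
    if load_time > slow_load_threshold:
        # Muted: only flag unusually slow data loading (threshold is hypothetical).
        return f"WARNING: slow batch loading ({load_time:.2f}s) at epoch {epoch}, batch {batch_idx}"
    return None
```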
-
processDatasets(DS: list)
This method takes the external Entity-Relation graph representation, one MetaRelation at a time, and converts it into lower-level data structures used within the wrapper, creating a MetaDataset structure from the input ER representation. This structure mimics the ERgraph but is suitable for efficient multi-task mini-batching during training. This function is used internally by the NNwrapper and does not need to be called by the user.
- Parameters
DS (MetaRelation) –
- Returns
DS (MetaRelation) – The original MetaRelation with the data matrices removed, in an attempt to save space (the savings have not yet been benchmarked).
datasets (list of SubDataset)
losses (list of losses)
refSize (size of the target matrix; to be removed)
meta private:
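One plausible way to turn a partially observed relation matrix into mini-batch-friendly samples is to enumerate (row, column, value) triples. This is a hypothetical sketch: `matrix_to_triples` is not part of NXTfusion. It also drops entries equal to ignore_index, whereas the constructor documentation suggests NXTfusion may keep them and let the loss ignore them.

```python
from typing import List, Sequence, Tuple

Triple = Tuple[int, int, float]


def matrix_to_triples(matrix: Sequence[Sequence[float]], ignore_index: float) -> List[Triple]:
    """Flatten a partially observed matrix into (row, col, value) samples.

    Entries equal to ignore_index (missing / N/A) are skipped in this sketch.
    """
    return [(i, j, v)
            for i, row in enumerate(matrix)
            for j, v in enumerate(row)
            if v != ignore_index]
```

A flat list of triples can be shuffled and sliced into batches of any size, independently of the matrix shape.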
-