#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# pytorchDatasetUtils.py
#
# Copyright 2018 Daniele Raimondi <daniele.raimondi@vub.be>
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston,
# MA 02110-1301, USA.
#
#
import pickle as cPickle
import time
import torch as t
import numpy as np
from torch.autograd import Variable
from torch.utils.data import Dataset, DataLoader
class PredictionDatasetSide(Dataset):
    """
    Dataset over (i, j) index pairs that also returns the side information
    looked up for each of the two indices.

    :meta private:
    """

    def __init__(self, x, sx1, sx2):
        """
        Parameters
        ----------
        x : sequence
            Sequence of index pairs, e.g. ``[(i, j), (i, k), ...]``.
        sx1 : indexable or None
            Side information indexed by the first element of each pair.
            ``None`` means no side information for that side.
        sx2 : indexable or None
            Side information indexed by the second element of each pair.
            ``None`` means no side information for that side.
        """
        self.x = x  # [(i,j), (i,k), ...]
        self.sx1 = sx1
        self.sx2 = sx2

    def __getitem__(self, idx):
        """
        Return ``(i, j, side1, side2)`` for the idx-th pair.

        ``side1`` is ``sx1[i]`` and ``side2`` is ``sx2[j]``; when the
        corresponding source is ``None`` an empty list is returned instead.

        :meta private:
        """
        tmp = self.x[idx]
        # Use identity checks: with `!= None`, numpy arrays compare
        # elementwise and the resulting array has an ambiguous truth value.
        sx1 = self.sx1[tmp[0]] if self.sx1 is not None else []
        sx2 = self.sx2[tmp[1]] if self.sx2 is not None else []
        return tmp[0], tmp[1], sx1, sx2

    def __len__(self):
        """Number of index pairs."""
        return len(self.x)
class PredictionDataset(Dataset):
    """
    Minimal Dataset over a sequence of index pairs.

    Each item is returned as the pair ``(first, second)`` taken from the
    first two elements of the stored entry.
    """

    def __init__(self, x, label = True):
        """
        Parameters
        ----------
        x : sequence
            Sequence of index pairs, e.g. ``[(i, j), (i, k), ...]``.
        label : bool
            Accepted for interface compatibility; it is not used.
        """
        self.x = x  # [(i,j), (i,k), ...]

    def __getitem__(self, idx):
        """Return the first two elements of the idx-th entry as a tuple."""
        entry = self.x[idx]
        first = entry[0]
        second = entry[1]
        return first, second

    def __len__(self):
        """Number of stored entries."""
        return len(self.x)
class SideDataset(Dataset):
    """
    Dataset wrapping a dict of side-information entries, looked up by key.
    """

    def __init__(self, side):
        """
        Parameters
        ----------
        side : dict
            Mapping from key to a side-information entry (anything with a
            ``len``).

        Notes
        -----
        ``estSize`` is ``len(side)`` times the length of the first entry
        (0 for an empty dict); presumably all entries share that length.
        """
        assert isinstance(side, dict), "side must be a dict"
        self.side = side
        if side:
            # dict views are not indexable in Python 3; take the first value
            # through an iterator instead.
            first = next(iter(side.values()))
            self.estSize = len(self) * len(first)
        else:
            self.estSize = 0

    def __getitem__(self, idx):
        """Return the side-information entry stored under key ``idx``."""
        return self.side[idx]

    def __len__(self):
        """Number of entries in the wrapped dict."""
        return len(self.side)
class SubDataset(Dataset):
    """
    Within the NNwrapper, during training, batches need to be rapidly provided for all the MetaRelations in the ERgraph and for each Relation in every MetaRelation. To do so, the NNwrapper.processDatasets function builds an internal Dataset structure that mimics the structure of the input ERgraph. In this structure, a MetaDataset corresponds to a MetaRelation, and each Relation in a MetaRelation is represented by a SubDataset in the corresponding MetaDataset.
    Nevertheless, this is internal and it is transparent to the user.

    :meta private:
    """

    def __init__(self, xht, typep="binary"):
        """
        Constructor method for the SubDataset class. It puts in a pytorch-friendly structure the matrix corresponding to a target Relation, by transforming its DataMatrix into a pytorch Dataset.

        Parameters
        ----------
        xht : dict
            Dict used to represent the matrix/relation data within a DataMatrix object,
            laid out as ``{p1: [(positions), (values)]}``: for each value, index 0
            holds the positions and index 1 the corresponding values.
        typep : str
            String specifying the type of the prediction. It must be "regression" or "binary".

        Raises
        ------
        AssertionError
            If ``xht`` is not a dict (only when assertions are enabled).
        """
        assert isinstance(xht, dict), "xht must be a dict"
        self.xht = xht  # xht = {p1:[(positions),(values)]}
        # Rows with no stored positions.
        empty = sum(1 for row in self.xht.values() if len(row[0]) == 0)
        print ("Empty rows: ",empty)
        self.estSize = self.countInstances()
        # countBalance reads self.type, so it must be set before calling it.
        self.type = typep
        self.balance = self.countBalance()

    def countBalance(self):
        """
        Count the positive/negative balance of a binary relation.

        Returns
        -------
        list or str
            ``[positives, negatives]``, where ``positives`` is the sum of all
            stored values and ``negatives`` is the number of stored values minus
            that sum. Returns the string "regression" when ``self.type`` is not
            "binary".
        """
        if self.type != "binary":
            return "regression"
        r = [0, 0]
        for row in self.xht.values():
            labels = row[1]
            positives = sum(labels)
            r[0] += positives
            r[1] += len(labels) - positives
        return r

    def countInstances(self):
        """Return the total number of stored positions across all rows."""
        return sum(len(row[0]) for row in self.xht.values())

    def __getitem__(self, idx):
        """Return the ``[(positions), (values)]`` entry stored under key ``idx``."""
        return self.xht[idx]

    @staticmethod
    def load(name):
        """
        Load a pickled ``xht`` dict from ``name`` and wrap it in a SubDataset.

        The prediction type is inferred from the file name: "binary" if it
        contains "binary", else "regression" if it contains "regression",
        else None.
        """
        # NOTE(review): pickle can execute arbitrary code while loading; only
        # load files from trusted sources.
        with open(name, "rb") as fh:  # pickle needs a binary file object
            tmp = cPickle.load(fh)
        vt = None
        if "binary" in name:
            vt = "binary"
        elif "regression" in name:
            vt = "regression"
        return SubDataset(tmp, typep=vt)

    def dump(self, name):
        """Pickle the underlying ``xht`` dict to ``name`` and report the elapsed time."""
        print( "Dumping...")
        t1 = time.time()
        with open(name, "wb") as fh:  # pickle writes bytes
            cPickle.dump(self.xht, fh)
        t2 = time.time()
        print ("Stored in: %s (%.2fs)" % ( name, t2-t1))

    def __len__(self):
        """Number of rows (keys) in ``xht``."""
        return len(self.xht)