Make header inherit from tuple

Rather than having a header class that wraps a list of attribute names,
make the header class itself an immutable tuple.

This simplifies the code because a header can now be compared and
indexed like any other tuple.

Code had to be changed all over the place to cope with this new
datatype.
This commit is contained in:
Salvo 'LtWorf' Tomaselli
2015-06-06 15:22:11 +02:00
parent d754a166a1
commit a033cb769a
2 changed files with 59 additions and 88 deletions

View File

@@ -218,7 +218,7 @@ class node (object):
return
if self.kind == RELATION:
return list(rels[self.name].header.attributes)
return list(rels[self.name].header)
elif self.kind == BINARY and self.name in (DIFFERENCE, UNION, INTERSECTION):
return self.left.result_format(rels)
elif self.kind == BINARY and self.name == DIVISION:

View File

@@ -86,7 +86,7 @@ class relation (object):
writer = csv.writer(fp) # Creating csv writer
# It wants an iterable containing iterables
head = (self.header.attributes,)
head = (self.header,)
writer.writerows(head)
# Writing content, already in the correct format
@@ -101,10 +101,10 @@ class relation (object):
Will return None if they don't share the same attributes'''
if (self.__class__ != other.__class__):
raise Exception('Expected an instance of the same class')
if self.header.sharedAttributes(other.header) == len(self.header.attributes):
return other.projection(list(self.header.attributes))
if self.header.sharedAttributes(other.header) == len(self.header):
return other.projection(self.header)
raise Exception('Relations differ: [%s] [%s]' % (
','.join(self.header.attributes) , ','.join(other.header.attributes)
','.join(self.header) , ','.join(other.header)
))
def selection(self, expr):
@@ -112,11 +112,11 @@ class relation (object):
constant, math operations and boolean ones.'''
attributes = {}
newt = relation()
newt.header = header(list(self.header.attributes))
newt.header = header(self.header)
for i in self.content:
# Fills the attributes dictionary with the values of the tuple
for j in range(len(self.header.attributes)):
attributes[self.header.attributes[j]] = i[j].autocast()
for j,attr in enumerate(self.header):
attributes[attr] = i[j].autocast()
try:
if eval(expr, attributes):
@@ -136,7 +136,7 @@ class relation (object):
raise Exception(
'Unable to perform product on relations with colliding attributes')
newt = relation()
newt.header = header(self.header.attributes + other.header.attributes)
newt.header = header(self.header + other.header)
for i in self.content:
for j in other.content:
@@ -149,16 +149,16 @@ class relation (object):
Will delete duplicate items
If an empty list or no parameters are provided, returns None'''
# Parameters are supplied in a list, instead with multiple parameters
if isinstance(attributes[0], list):
if not isinstance(attributes[0], str):
attributes = attributes[0]
ids = self.header.getAttributesId(attributes)
if len(ids) == 0 or len(ids) != len(attributes):
if len(ids) == 0:
raise Exception('Invalid attributes for projection')
newt = relation()
# Create the header
h = [self.header.attributes[i] for i in ids]
h = (self.header[i] for i in ids)
newt.header = header(h)
# Create the body
@@ -175,10 +175,7 @@ class relation (object):
result = []
newt = relation()
newt.header = header(list(self.header.attributes))
for old, new in params.items():
newt.header.rename(old, new)
newt.header = self.header.rename(params)
newt.content = self.content
newt._readonly = True
@@ -193,7 +190,7 @@ class relation (object):
It is possible to use projection and rename to make headers match.'''
other = self._rearrange_(other) # Rearranges attributes' order
newt = relation()
newt.header = header(list(self.header.attributes))
newt.header = header(self.header)
newt.content = self.content.intersection(other.content)
return newt
@@ -206,7 +203,7 @@ class relation (object):
It is possible to use projection and rename to make headers match.'''
other = self._rearrange_(other) # Rearranges attributes' order
newt = relation()
newt.header = header(list(self.header.attributes))
newt.header = header(self.header)
newt.content = self.content.difference(other.content)
return newt
@@ -221,8 +218,7 @@ class relation (object):
'''
# d_headers are the headers from self that aren't also headers in other
d_headers = list(
set(self.header.attributes) - set(other.header.attributes))
d_headers = tuple(set(self.header) - set(other.header))
'''
Wikipedia defines the division as follows:
@@ -249,7 +245,7 @@ class relation (object):
It is possible to use projection and rename to make headers match.'''
other = self._rearrange_(other) # Rearranges attributes' order
newt = relation()
newt.header = header(list(self.header.attributes))
newt.header = header(self.header)
newt.content = self.content.union(other.content)
return newt
@@ -283,14 +279,11 @@ class relation (object):
shared = self.header.intersection(other.header)
newt = relation() # Creates the new relation
# Creating the header with all the fields, done like that because order is
# needed
h = (i for i in other.header if i not in shared)
newt.header = header(chain(self.header,h))
# Adds all the attributes of the 1st relation
newt.header = header(list(self.header.attributes))
# Adds all the attributes of the 2nd, when non shared
for i in other.header:
if i not in shared:
newt.header.attributes.append(i)
# Shared ids of self
sid = self.header.getAttributesId(shared)
# Shared ids of the other relation
@@ -324,16 +317,14 @@ class relation (object):
shared attributes, it will behave as cartesian product.'''
# List of attributes in common between the relations
shared = set(self.header).intersection(set(other.header))
shared = self.header.intersection(other.header)
newt = relation() # Creates the new relation
# Adding to the headers all the fields, done like that because order is
# Creating the header with all the fields, done like that because order is
# needed
newt.header = header(list(self.header.attributes))
for i in other.header.attributes:
if i not in shared:
newt.header.attributes.append(i)
h = (i for i in other.header if i not in shared)
newt.header = header(chain(self.header,h))
# Shared ids of self
sid = self.header.getAttributesId(shared)
@@ -341,10 +332,7 @@ class relation (object):
oid = other.header.getAttributesId(shared)
# Non shared ids of the other relation
noid = []
for i in range(len(other.header.attributes)):
if i not in oid:
noid.append(i)
noid = [i for i in range(len(other.header)) if i not in oid]
for i in self.content:
for j in other.content:
@@ -364,7 +352,7 @@ class relation (object):
if self.__class__ != other.__class__:
return False
if set(self.header.attributes) != set(other.header.attributes):
if set(self.header) != set(other.header):
return False
# Rearranges attributes' order so can compare tuples directly
@@ -380,7 +368,7 @@ class relation (object):
'''Returns a string representation of the relation, can be printed with
monospaced fonts'''
m_len = [] # Maximum lenght string
for f in self.header.attributes:
for f in self.header:
m_len.append(len(f))
for f in self.content:
@@ -391,8 +379,8 @@ class relation (object):
col += 1
res = ""
for f in range(len(self.header.attributes)):
res += "%s" % (self.header.attributes[f].ljust(2 + m_len[f]))
for f,attr in enumerate(self.header):
res += "%s" % (attr.ljust(2 + m_len[f]))
for r in self.content:
col = 0
@@ -420,8 +408,8 @@ class relation (object):
# new_content=[] #New content of the relation
for i in self.content:
for j in range(len(self.header.attributes)):
attributes[self.header.attributes[j]] = i[j].autocast()
for j,attr in enumerate(self.header):
attributes[attr] = i[j].autocast()
if eval(expr, attributes): # If expr is true, changing the tuple
affected += 1
@@ -441,10 +429,10 @@ class relation (object):
All the values will be converted in string.
Will return the number of inserted rows.'''
if len(self.header.attributes) != len(values):
if len(self.header) != len(values):
raise Exception(
'Tuple has the wrong size. Expected %d, got %d' % (
len(self.header.attributes),
len(self.header),
len(values)
)
)
@@ -470,73 +458,56 @@ class relation (object):
return len(self.content) - l
class header (object):
class header(tuple):
'''This class defines the header of a relation.
It is used within relations to know if requested operations are accepted'''
# Since relations are mutable we explicitly block hashing them
__hash__ = None
def __new__ (cls, fields):
return super(header, cls).__new__(cls, tuple(fields))
def __init__(self, attributes):
def __init__(self, *args, **kwargs):
'''Accepts a list with attributes' names. Names MUST be unique'''
self.attributes = attributes
for i in attributes:
for i in self:
if not is_valid_relation_name(i):
raise Exception('"%s" is not a valid attribute name' % i)
if len(attributes) != len(set(attributes)):
if len(self) != len(set(self)):
raise Exception('Attribute names must be unique')
def __repr__(self):
return "header(%s)" % (self.attributes.__repr__())
return "header(%s)" % super(header, self).__repr__()
def rename(self, old, new):
'''Renames a field. Doesn't check if it is a duplicate.
Returns True'''
def rename(self, params):
'''Returns a new header, with renamed fields.
if not is_valid_relation_name(new):
raise Exception('%s is not a valid attribute name' % new)
try:
id_ = self.attributes.index(old)
self.attributes[id_] = new
except:
raise Exception('Field not found: %s' & old)
return True
params is a dictionary of {old:new} names
'''
attrs = list(self)
for old,new in params.items():
if not is_valid_relation_name(new):
raise Exception('%s is not a valid attribute name' % new)
try:
id_ = attrs.index(old)
attrs[id_] = new
except:
raise Exception('Field not found: %s' % old)
return header(attrs)
def sharedAttributes(self, other):
'''Returns how many attributes this header has in common with a given one'''
return len(set(self.attributes).intersection(set(other.attributes)))
return len(set(self).intersection(set(other)))
def union(self, other):
'''Returns the union of the sets of attributes with another header.'''
return set(self.attributes).union(set(other.attributes))
return set(self).union(set(other))
def intersection(self, other):
'''Returns the set of common attributes with another header.'''
return set(self.attributes).intersection(set(other.attributes))
def __str__(self):
'''Returns String representation of the field's list'''
return self.attributes.__str__()
def __eq__(self, other):
return self.attributes == other.attributes
def __ne__(self, other):
return self.attributes != other.attributes
def __contains__(self, key):
return key in self.attributes
def __iter__(self):
return iter(self.attributes)
def __len__(self):
return len(self.attributes)
return set(self).intersection(set(other))
def getAttributesId(self, param):
'''Returns a list with numeric index corresponding to field's name'''
return [self.attributes.index(i) for i in param]
return [self.index(i) for i in param]