Improve type annotations

This commit is contained in:
Salvo 'LtWorf' Tomaselli 2020-06-03 07:01:52 +02:00
parent 4ef4d679ac
commit 6a33e8ee08
No known key found for this signature in database
GPG Key ID: B3A7CF0C801886CF

View File

@ -22,7 +22,8 @@
import csv import csv
from itertools import chain, repeat from itertools import chain, repeat
from collections import deque from collections import deque
from typing import List, Union, Set from typing import *
from pathlib import Path
from relational.rtypes import * from relational.rtypes import *
@ -53,11 +54,11 @@ class Relation:
''' '''
__hash__ = None # type: None __hash__ = None # type: None
def __init__(self, filename: str = '') -> None: def __init__(self, filename: Optional[Union[str, Path]] = None) -> None:
self._readonly = False self._readonly = False
self.content: Set[tuple] = set() self.content: Set[tuple] = set()
if len(filename) == 0: # Empty relation if filename is None: # Empty relation
self.header = Header([]) self.header = Header([])
return return
with open(filename) as fp: with open(filename) as fp:
@ -73,7 +74,7 @@ class Relation:
self._readonly = True self._readonly = True
copy._readonly = True copy._readonly = True
def _make_writable(self, copy_content : bool = True) -> None: def _make_writable(self, copy_content: bool = True) -> None:
'''If this relation is marked as readonly, this '''If this relation is marked as readonly, this
method will copy the content to make it writable too method will copy the content to make it writable too
@ -92,7 +93,7 @@ class Relation:
def __contains__(self, key): def __contains__(self, key):
return key in self.content return key in self.content
def save(self, filename: str) -> None: def save(self, filename: Union[Path, str]) -> None:
''' '''
Saves the relation in a file. Will save using the csv Saves the relation in a file. Will save using the csv
format as defined in RFC4180. format as defined in RFC4180.
@ -169,7 +170,7 @@ class Relation:
newt.content.add(i + j) newt.content.add(i + j)
return newt return newt
def projection(self, * attributes) -> 'Relation': def projection(self, *attributes) -> 'Relation':
''' '''
Can be called in two different ways: Can be called in two different ways:
a.projection('field1','field2') a.projection('field1','field2')
@ -200,7 +201,7 @@ class Relation:
newt.content.add(tuple(row)) newt.content.add(tuple(row))
return newt return newt
def rename(self, params: 'Relation') -> 'Relation': def rename(self, params: Dict[str, str]) -> 'Relation':
''' '''
Takes a dictionary. Takes a dictionary.
@ -505,7 +506,7 @@ class Header(tuple):
def __repr__(self): def __repr__(self):
return "Header(%s)" % super(Header, self).__repr__() return "Header(%s)" % super(Header, self).__repr__()
def rename(self, params) -> 'Header': def rename(self, params: Dict[str, str]) -> 'Header':
'''Returns a new header, with renamed fields. '''Returns a new header, with renamed fields.
params is a dictionary of {old:new} names params is a dictionary of {old:new} names
@ -525,15 +526,15 @@ class Header(tuple):
'''Returns how many attributes this header has in common with a given one''' '''Returns how many attributes this header has in common with a given one'''
return len(set(self).intersection(set(other))) return len(set(self).intersection(set(other)))
def union(self, other) -> set: def union(self, other: 'Header') -> Set[str]:
'''Returns the union of the sets of attributes with another header.''' '''Returns the union of the sets of attributes with another header.'''
return set(self).union(set(other)) return set(self).union(set(other))
def intersection(self, other) -> set: def intersection(self, other: 'Header') -> Set[str]:
'''Returns the set of common attributes with another header.''' '''Returns the set of common attributes with another header.'''
return set(self).intersection(set(other)) return set(self).intersection(set(other))
def getAttributesId(self, param) -> List[int]: def getAttributesId(self, param: Iterable[str]) -> List[int]:
'''Returns a list with numeric index corresponding to field's name''' '''Returns a list with numeric index corresponding to field's name'''
try: try:
return [self.index(i) for i in param] return [self.index(i) for i in param]