Improve type annotations

This commit is contained in:
Salvo 'LtWorf' Tomaselli 2020-06-03 07:01:52 +02:00
parent 4ef4d679ac
commit 6a33e8ee08
No known key found for this signature in database
GPG Key ID: B3A7CF0C801886CF

View File

@ -22,7 +22,8 @@
import csv
from itertools import chain, repeat
from collections import deque
from typing import List, Union, Set
from typing import *
from pathlib import Path
from relational.rtypes import *
@ -53,11 +54,11 @@ class Relation:
'''
__hash__ = None # type: None
def __init__(self, filename: str = '') -> None:
def __init__(self, filename: Optional[Union[str, Path]] = None) -> None:
self._readonly = False
self.content: Set[tuple] = set()
if len(filename) == 0: # Empty relation
if filename is None: # Empty relation
self.header = Header([])
return
with open(filename) as fp:
@ -73,7 +74,7 @@ class Relation:
self._readonly = True
copy._readonly = True
def _make_writable(self, copy_content : bool = True) -> None:
def _make_writable(self, copy_content: bool = True) -> None:
'''If this relation is marked as readonly, this
method will copy the content to make it writable too
@ -92,7 +93,7 @@ class Relation:
def __contains__(self, key):
return key in self.content
def save(self, filename: str) -> None:
def save(self, filename: Union[Path, str]) -> None:
'''
Saves the relation in a file. Will save using the csv
format as defined in RFC4180.
@ -169,7 +170,7 @@ class Relation:
newt.content.add(i + j)
return newt
def projection(self, * attributes) -> 'Relation':
def projection(self, *attributes) -> 'Relation':
'''
Can be called in two different ways:
a.projection('field1','field2')
@ -200,7 +201,7 @@ class Relation:
newt.content.add(tuple(row))
return newt
def rename(self, params: 'Relation') -> 'Relation':
def rename(self, params: Dict[str, str]) -> 'Relation':
'''
Takes a dictionary.
@ -505,7 +506,7 @@ class Header(tuple):
def __repr__(self):
return "Header(%s)" % super(Header, self).__repr__()
def rename(self, params) -> 'Header':
def rename(self, params: Dict[str, str]) -> 'Header':
'''Returns a new header, with renamed fields.
params is a dictionary of {old:new} names
@ -525,15 +526,15 @@ class Header(tuple):
'''Returns how many attributes this header has in common with a given one'''
return len(set(self).intersection(set(other)))
def union(self, other) -> set:
def union(self, other: 'Header') -> Set[str]:
'''Returns the union of the sets of attributes with another header.'''
return set(self).union(set(other))
def intersection(self, other) -> set:
def intersection(self, other: 'Header') -> Set[str]:
'''Returns the set of common attributes with another header.'''
return set(self).intersection(set(other))
def getAttributesId(self, param) -> List[int]:
def getAttributesId(self, param: Iterable[str]) -> List[int]:
'''Returns a list with numeric index corresponding to field's name'''
try:
return [self.index(i) for i in param]