Make a bunch of stuff private

Only export the entry points, not all the inner functions.
This commit is contained in:
Salvo 'LtWorf' Tomaselli 2020-08-26 17:26:55 +02:00
parent a727c51e75
commit 9d6402b48c
No known key found for this signature in database
GPG Key ID: B3A7CF0C801886CF
2 changed files with 23 additions and 36 deletions

View File

@ -49,13 +49,13 @@ def optimize_program(code: str, rels: Dict[str, Relation]) -> str:
res, query = UserInterface.split_query(line)
last_res = res
parsed = tree(query)
replace_leaves(parsed, context)
_replace_leaves(parsed, context)
context[res] = parsed
node = optimize_all(context[last_res], rels, tostr=False)
return querysplit.split(node, rels)
def replace_leaves(node: Node, context: Dict[str, Node]) -> None:
def _replace_leaves(node: Node, context: Dict[str, Node]) -> None:
'''
If a name appearing in node appears
also in context, the parse tree is
@ -63,13 +63,13 @@ def replace_leaves(node: Node, context: Dict[str, Node]) -> None:
subtree found in context.
'''
if isinstance(node, Unary):
replace_leaves(node.child, context)
_replace_leaves(node.child, context)
if isinstance(node.child, Variable) and node.child.name in context:
node.child = context[node.child.name]
elif isinstance(node, Binary):
replace_leaves(node.left, context)
replace_leaves(node.right, context)
_replace_leaves(node.left, context)
_replace_leaves(node.right, context)
if isinstance(node.left, Variable) and node.left.name in context:
node.left = context[node.left.name]
if isinstance(node.right, Variable) and node.right.name in context:
@ -93,20 +93,20 @@ def optimize_all(expression: Union[str, Node], rels: Dict[str, Relation], specif
elif isinstance(expression, Node):
n = expression
else:
raise (TypeError("expression must be a string or a node"))
raise TypeError('expression must be a string or a node')
total = 1
while total != 0:
total = 0
if specific:
for i in optimizations.specific_optimizations:
n, c = recursive_scan(i, n, rels)
n, c = _recursive_scan(i, n, rels)
if c != 0 and isinstance(debug, list):
debug.append(str(n))
total += c
if general:
for j in optimizations.general_optimizations:
n, c = recursive_scan(j, n, None)
n, c = _recursive_scan(j, n, None)
if c != 0 and isinstance(debug, list):
debug.append(str(n))
total += c
@ -116,28 +116,7 @@ def optimize_all(expression: Union[str, Node], rels: Dict[str, Relation], specif
return n
def specific_optimize(expression, rels: Dict[str, Relation]):
'''This function performs specific optimizations. Means that it will need to
know the fields used by the relations.
expression : see documentation of this module
rels: dic with relation name as key, and relation istance as value
Return value: this will return an optimized version of the expression'''
return optimize_all(expression, rels, specific=True, general=False)
def general_optimize(expression):
'''This function performs general optimizations. Means that it will not need to
know the fields used by the relations
expression : see documentation of this module
Return value: this will return an optimized version of the expression'''
return optimize_all(expression, None, specific=False, general=True)
def recursive_scan(function, node: Node, rels: Optional[Dict[str, Any]]) -> Tuple[Node, int]:
def _recursive_scan(function, node: Node, rels: Optional[Dict[str, Any]]) -> Tuple[Node, int]:
'''Does a recursive optimization on the tree.
This function will recursively execute the function given
@ -159,11 +138,11 @@ def recursive_scan(function, node: Node, rels: Optional[Dict[str, Any]]) -> Tupl
changes += c
if isinstance(node, Unary):
node.child, c = recursive_scan(function, node.child, rels)
node.child, c = _recursive_scan(function, node.child, rels)
changes += c
elif isinstance(node, Binary):
node.left, c = recursive_scan(function, node.left, rels)
node.left, c = _recursive_scan(function, node.left, rels)
changes += c
node.right, c = recursive_scan(function, node.right, rels)
node.right, c = _recursive_scan(function, node.right, rels)
changes += c
return node, changes

View File

@ -23,11 +23,14 @@ from typing import List, Dict, Tuple
from relational.parser import Node, Binary, Unary, Variable
__all__ = ['split']
class Program:
def __init__(self, rels) -> None:
self.queries: List[Tuple[str, Node]] = []
self.dictionary: Dict[str, Node] = {} # Key is the query, value is the relation
self.vgen = vargen(rels, 'optm_')
self.vgen = _vargen(rels, 'optm_')
def __str__(self):
r = ''
@ -48,6 +51,7 @@ class Program:
self.dictionary[strnode] = n
return n
def _separate(node: Node, program: Program) -> None:
if isinstance(node, Unary) and isinstance(node.child, Variable):
_separate(node.child, program)
@ -64,7 +68,8 @@ def _separate(node: Node, program: Program) -> None:
node.right = rel
program.append_query(node)
def vargen(avoid: str, prefix: str=''):
def _vargen(avoid: str, prefix: str=''):
'''
Generates temp variables.
@ -86,12 +91,15 @@ def vargen(avoid: str, prefix: str=''):
yield r
count += 1
def split(node, rels) -> str:
'''
Split a query into a program.
The idea is that if there are duplicated subdtrees they
The idea is that if there are duplicated subtrees they
get executed only once.
This is used by the optimizer module.
'''
p = Program(rels)
_separate(node, p)