diff --git a/relational/optimizations.py b/relational/optimizations.py index a23f7ed..d38976e 100644 --- a/relational/optimizations.py +++ b/relational/optimizations.py @@ -433,40 +433,24 @@ def select_union_intersect_subtract(n: parser.Node) -> int: return changes + recoursive_scan(select_union_intersect_subtract, n) -def union_and_product(n: parser.Node) -> int: +def union_and_product(n: parser.Node) -> Tuple[parser.Node, int]: ''' A * B ∪ A * C = A * (B ∪ C) Same thing with inner join ''' - - changes = 0 if n.name == UNION and n.left.name in {PRODUCT, JOIN} and n.left.name == n.right.name: - newnode = parser.Node() - newnode.kind = parser.BINARY - newnode.name = n.left.name - - newchild = parser.Node() - newchild.kind = parser.BINARY - newchild.name = UNION - if n.left.left == n.right.left or n.left.left == n.right.right: - newnode.left = n.left.left - newnode.right = newchild - - newchild.left = n.left.right - newchild.right = n.right.left if n.left.left == n.right.right else n.right.right - replace_node(n, newnode) - changes = 1 + l = n.left.right + r = n.right.left if n.left.left == n.right.right else n.right.right + newchild = parser.Binary(UNION, l, r) + return parser.Binary(n.left.name, n.left.left, newchild), 1 elif n.left.right == n.right.left or n.left.left == n.right.right: - newnode.left = n.left.right - newnode.right = newchild - - newchild.left = n.left.left - newchild.right = n.right.left if n.right.left == n.right.right else n.right.right - replace_node(n, newnode) - changes = 1 - return changes + recoursive_scan(union_and_product, n) + l = n.left.left + r = n.right.left if n.right.left == n.right.right else n.right.right + newchild = parser.Binary(UNION, l, r) + return parser.Binary(n.left.name, n.left.right, newchild), 1 + return n, 0 def projection_and_union(n, rels): @@ -498,7 +482,7 @@ def projection_and_union(n, rels): newnode.prop = n.right.prop replace_node(n, newnode) changes = 1 - return changes + recoursive_scan(projection_and_union, n, rels) + return n, 0 def selection_and_product(n, rels): @@ -623,7 +607,7 @@ general_optimizations = [ swap_union_renames, swap_rename_projection, #select_union_intersect_subtract, - #union_and_product, + union_and_product, ] specific_optimizations = [ #selection_and_product,