kse-02/operators.py

175 lines
5.6 KiB
Python

import sys
from dataclasses import dataclass
from typing import Generic, Dict, List, TypeVar, Callable, Union, Tuple
from nltk import edit_distance
# Per-condition branch-distance records, keyed by condition number.
# distances_true / distances_false hold the smallest distance recorded so far;
# the *_all variants hold every recorded distance, in recording order.
distances_true: Dict[int, int] = {}
distances_false: Dict[int, int] = {}
distances_true_all: Dict[int, List[int]] = {}
distances_false_all: Dict[int, List[int]] = {}

T = TypeVar('T')
U = TypeVar('U')


@dataclass
class CmpOp(Generic[T]):
    """A comparison operator bundled with its branch-distance functions.

    Attributes:
        operator: Source-level spelling of the operator, e.g. ``'<'``.
        name: Name the operator is looked up by, e.g. ``'Lt'``.
        test: Evaluates the comparison on ``(lhs, rhs)``.
        true_dist: Distance from making the comparison true (0 if it already holds).
        false_dist: Distance from making the comparison false (0 if it already fails).
    """
    # @dataclass generates __init__(operator, name, test, true_dist, false_dist)
    # from these fields in declaration order, so no hand-written constructor
    # is needed (and none can fall out of sync with the fields).
    operator: str
    name: str
    test: Callable[[T, T], bool]
    true_dist: Callable[[T, T], int]
    false_dist: Callable[[T, T], int]
# Comparison operators whose operands must both be ints: every distance
# function subtracts its operands. compute_distances also routes pairs of
# single-character strings here, converting them to code points with
# int_str_convert first.
# true_dist: how far lhs is from satisfying the test (0 if it already holds);
# false_dist: how far lhs is from falsifying it (0 if it already fails).
int_str_ops: List[CmpOp[int]] = [
    CmpOp(operator='<',
          name='Lt',
          test=lambda lhs, rhs: lhs < rhs,
          true_dist=lambda lhs, rhs: lhs - rhs + 1 if lhs >= rhs else 0,
          false_dist=lambda lhs, rhs: rhs - lhs if lhs < rhs else 0),
    CmpOp(operator='>',
          name='Gt',
          test=lambda lhs, rhs: lhs > rhs,
          true_dist=lambda lhs, rhs: rhs - lhs + 1 if lhs <= rhs else 0,
          false_dist=lambda lhs, rhs: lhs - rhs if lhs > rhs else 0),
    CmpOp(operator='<=',
          name='LtE',
          test=lambda lhs, rhs: lhs <= rhs,
          true_dist=lambda lhs, rhs: lhs - rhs if lhs > rhs else 0,
          false_dist=lambda lhs, rhs: rhs - lhs + 1 if lhs <= rhs else 0),
    CmpOp(operator='>=',
          name='GtE',
          test=lambda lhs, rhs: lhs >= rhs,
          true_dist=lambda lhs, rhs: rhs - lhs if lhs < rhs else 0,
          false_dist=lambda lhs, rhs: lhs - rhs + 1 if lhs >= rhs else 0),
    CmpOp(operator='==',
          name='Eq',
          test=lambda lhs, rhs: lhs == rhs,
          true_dist=lambda lhs, rhs: abs(lhs - rhs),
          false_dist=lambda lhs, rhs: 1 if lhs == rhs else 0),
    CmpOp(operator='!=',
          name='NotEq',
          test=lambda lhs, rhs: lhs != rhs,
          true_dist=lambda lhs, rhs: 1 if lhs == rhs else 0,
          false_dist=lambda lhs, rhs: abs(lhs - rhs)),
]
# Lookup table from operator name ('Lt', 'Eq', ...) to the int operators above.
int_str_by_name: Dict[str, CmpOp[int]] = {c.name: c for c in int_str_ops}
def int_str_check(a: any, b: any) -> bool:
    """Tell whether ``a`` and ``b`` can be compared with the ``int_str`` operators.

    True when both are exactly ``int`` (``bool`` and other subclasses do not
    count), or both are exactly ``str`` of length one. Mixed pairs are rejected.
    """
    if type(a) is int and type(b) is int:
        return True
    both_strings = type(a) is str and type(b) is str
    return both_strings and len(a) == 1 and len(b) == 1
def int_str_convert(x: Union[int, str]) -> int:
    """Map an ``int_str`` operand to the integer the comparison operators use.

    Ints (exactly ``int``) are returned unchanged; a one-character ``str``
    becomes its code point via ``ord``.

    Raises:
        ValueError: for any other value, including strings whose length is not
            one and non-``str`` types such as ``float``, ``bool`` or ``bytes``.
    """
    if type(x) == int:
        return x
    # Check the type explicitly (as int_str_check does) so that non-str values
    # reach the ValueError below instead of leaking a TypeError from len()/ord()
    # or, for 1-byte bytes objects, being silently accepted.
    if type(x) == str and len(x) == 1:
        return ord(x)
    raise ValueError("x must be int or len(str) == 1")
# Comparison operators whose operands must both be strings. Distances for the
# equality tests are edit distances computed with nltk's edit_distance.
str_ops: List[CmpOp[str]] = [
    CmpOp(operator='==',
          name='Eq',
          test=lambda lhs, rhs: lhs == rhs,
          true_dist=lambda lhs, rhs: edit_distance(lhs, rhs),
          false_dist=lambda lhs, rhs: 1 if lhs == rhs else 0),
    CmpOp(operator='!=',
          name='NotEq',
          test=lambda lhs, rhs: lhs != rhs,
          true_dist=lambda lhs, rhs: 1 if lhs == rhs else 0,
          false_dist=lambda lhs, rhs: edit_distance(lhs, rhs)),
]
# Lookup table from operator name to the str operators above; the values are
# the CmpOp[str] entries of str_ops, so annotate them as such.
str_by_name: Dict[str, CmpOp[str]] = {c.name: c for c in str_ops}
def str_check(a: any, b: any) -> bool:
    """Tell whether both operands are exactly ``str`` (subclasses excluded)."""
    return all(type(operand) == str for operand in (a, b))
def compute_distances(name: str, lhs: any, rhs: any) -> Tuple[int, int, bool]:
    """Evaluate comparison ``name`` on ``(lhs, rhs)`` and measure its branch distances.

    Pairs of exact ints and pairs of one-character strings use the ``int_str``
    operators (characters compared by code point). Any other pair of strings
    uses the ``str`` operators.

    Returns:
        ``(true_distance, false_distance, outcome)`` for the selected operator.

    Raises:
        ValueError: if the operands fit neither operator family, or ``name`` is
            not an operator of the family selected by the operand types.
    """
    if not int_str_check(lhs, rhs):
        if not str_check(lhs, rhs):
            raise ValueError(f"'{lhs}' and '{rhs}' are not suitable for both 'int_str' and 'str' operators")
        chosen = str_by_name.get(name)
        if chosen is None:
            raise ValueError(f"'{name}' is not a valid CmpOp name for 'str' operators")
        left, right = lhs, rhs
    else:
        left, right = int_str_convert(lhs), int_str_convert(rhs)
        chosen = int_str_by_name.get(name)
        if chosen is None:
            raise ValueError(f"'{name}' is not a valid CmpOp name for 'int_str' operators")
    return chosen.true_dist(left, right), chosen.false_dist(left, right), chosen.test(left, right)
def update_map(the_map: Dict[int, int], condition_num: int, distance: int):
    """Record ``distance`` for ``condition_num`` in ``the_map``, keeping the minimum seen."""
    # A missing key defaults to the new distance itself, so min() just stores it.
    previous = the_map.get(condition_num, distance)
    the_map[condition_num] = min(previous, distance)
def update_maps(condition_num, d_true, d_false):
    """Record one evaluation of condition ``condition_num``.

    The true distance updates the running minimum in ``distances_true`` and is
    appended to ``distances_true_all``; the false distance then does the same
    with ``distances_false`` / ``distances_false_all``.
    """
    # The module-level dicts are only mutated in place, never rebound, so no
    # ``global`` declaration is needed.
    sides = (
        (distances_true, distances_true_all, d_true),
        (distances_false, distances_false_all, d_false),
    )
    for best, history, distance in sides:
        update_map(best, condition_num, distance)
        history.setdefault(condition_num, []).append(distance)
def in_op(num, lhs, rhs):
    """Evaluate ``lhs in rhs`` for condition ``num`` and record its branch distances.

    The returned outcome is Python's own ``lhs in rhs`` (key membership for a
    dict), so the instrumented program takes exactly the branch it would take
    uninstrumented.

    Distances passed to update_maps:
      * true distance: 0 when the test holds. Otherwise it is the smallest
        absolute difference between ``lhs`` and any element of ``rhs``, floored
        at 1. Here one-character strings are compared by code point and
        ints/floats directly, and elements with no such encoding are skipped.
        It is ``sys.maxsize`` when nothing is comparable (e.g. empty ``rhs``).
      * false distance: 1 when the test holds, 0 otherwise.

    ``rhs`` must be a re-iterable container (dict, set, list, tuple, str, ...):
    it is first tested with ``in`` and then iterated.
    """
    def encode(value):
        # Numeric stand-in used only for the distance heuristic; None = skip.
        if isinstance(value, str):
            return ord(value) if len(value) == 1 else None
        if isinstance(value, (int, float)):
            return value
        return None

    outcome = lhs in rhs
    if outcome:
        distance_true = 0
    else:
        target = encode(lhs)
        closest = sys.maxsize
        if target is not None:
            for elem in rhs:
                candidate = encode(elem)
                if candidate is not None:
                    closest = min(closest, abs(target - candidate))
        # Different values can share an encoding (e.g. 97 and 'a'); a miss
        # must never be reported with the "satisfied" distance 0.
        distance_true = max(1, closest)
    distance_false = 1 if outcome else 0
    update_maps(num, distance_true, distance_false)
    return outcome
def evaluate_condition(num, op, lhs, rhs):
    """Evaluate comparison ``op`` for condition ``num``, record its distances, return the outcome.

    ``op`` is an operator name. ``"In"`` is delegated to in_op, which records
    its own distances. Every other name goes through compute_distances, and the
    resulting distances are recorded with update_maps.
    """
    if op != "In":
        d_true, d_false, outcome = compute_distances(op, lhs, rhs)
        update_maps(num, d_true, d_false)
        return outcome
    return in_op(num, lhs, rhs)