Implement proper semantic checking for full AST

This commit is contained in:
Bananymous 2024-04-26 13:36:20 +03:00
parent 7a306b7904
commit 04a899247e
4 changed files with 947 additions and 0 deletions

View File

@ -0,0 +1,306 @@
#!/usr/bin/env python3
import lexer
import ply.lex as lex
import ply.yacc as yacc
# Token names are taken straight from the lexer module so that the grammar
# rules below and the lexer agree on one list (yacc.yacc() is expected to
# look this module-level name up when building the parser).
tokens = lexer.tokens
class ASTnode:
    """One node of the abstract syntax tree.

    Attributes created by the constructor:
      type     -- starts as None
      nodetype -- the node kind string given as `typestr`
      lineno   -- the line number given by the caller
      value    -- only created when a non-None `value` is passed
    Grammar actions attach further `child_*` / `children_*` attributes.
    """

    def __init__(self, typestr, lineno, value = None):
        self.type = None
        self.nodetype, self.lineno = typestr, lineno
        if value is None:
            # No payload: the `value` attribute is deliberately left undefined.
            return
        self.value = value
def p_program1(p):
    'program : statement_list'
    # A program with no definitions: the definition list is empty and the
    # statements are taken over from the statement_list wrapper.
    root = ASTnode('program', p.lineno(1))
    root.children_definitions = []
    root.children_statements = p[1].children_statements
    p[0] = root


def p_program2(p):
    'program : definition_list statement_list'
    # A program with definitions followed by statements; both wrapper nodes
    # are unpacked into the program node.
    root = ASTnode('program', p.lineno(1))
    root.children_definitions = p[1].children_definitions
    root.children_statements = p[2].children_statements
    p[0] = root
def p_statement_list1(p):
    'statement_list : statement'
    # First statement: start a new wrapper holding a one-element list.
    wrapper = ASTnode('statement_list', p.lineno(1))
    wrapper.children_statements = [p[1]]
    p[0] = wrapper


def p_statement_list2(p):
    'statement_list : statement_list COMMA statement'
    # Further statements are appended in place to the existing wrapper.
    collected = p[1]
    collected.children_statements.append(p[3])
    p[0] = collected
def p_definition_list1(p):
    'definition_list : definition'
    # First definition: start a new wrapper holding a one-element list.
    wrapper = ASTnode('definition_list', p.lineno(1))
    wrapper.children_definitions = [p[1]]
    p[0] = wrapper


def p_definition_list2(p):
    'definition_list : definition_list definition'
    # Definitions are not comma separated; each one is appended in place.
    collected = p[1]
    collected.children_definitions.append(p[2])
    p[0] = collected
def p_definition(p):
    '''definition : function_definition
                  | procedure_definition
                  | variable_definition'''
    # Pure pass-through: whichever alternative matched already built its node.
    matched = p[1]
    p[0] = matched
def p_variable_definition(p):
    'variable_definition : VAR IDENT EQ expression'
    # The node value is the variable name; the initialiser expression is
    # attached as the only child.
    definition = ASTnode('variable_definition', p.lineno(1), p[2])
    definition.child_expression = p[4]
    p[0] = definition
def p_empty(p):
    'empty :'
    # Epsilon production: nothing is stored in p[0].
    return None
def p_variable_definition_list1(p):
    'variable_definition_list : empty'
    # Base case: a (possibly permanent) empty list of local definitions.
    wrapper = ASTnode('variable_definition_list', p.lineno(1))
    wrapper.children_definitions = []
    p[0] = wrapper


def p_variable_definition_list2(p):
    'variable_definition_list : variable_definition_list variable_definition'
    # Each further local variable definition is appended in place.
    collected = p[1]
    collected.children_definitions.append(p[2])
    p[0] = collected
def p_function_definition(p):
    '''function_definition : FUNCTION FUNC_IDENT LCURLY formal_list RCURLY RETURN IDENT variable_definition_list IS rvalue END FUNCTION'''
    # Node value is the function name, line number is that of the name token.
    # The return type is kept as the raw IDENT string; the body is a single rvalue.
    function = ASTnode('function_definition', p.lineno(2), p[2])
    function.children_formals = p[4].children_formals
    function.child_return_type = p[7]
    function.children_variable_definitions = p[8].children_definitions
    function.child_expression = p[10]
    p[0] = function
def p_procedure_definition1(p):
    'procedure_definition : PROCEDURE PROC_IDENT LCURLY formal_list RCURLY variable_definition_list IS statement_list END PROCEDURE'
    # Procedure without a RETURN clause: the return type is recorded as None.
    procedure = ASTnode('procedure_definition', p.lineno(2), p[2])
    procedure.children_formals = p[4].children_formals
    procedure.children_variable_definitions = p[6].children_definitions
    procedure.children_statements = p[8].children_statements
    procedure.child_return_type = None
    p[0] = procedure


def p_procedure_definition2(p):
    '''procedure_definition : PROCEDURE PROC_IDENT LCURLY formal_list RCURLY RETURN IDENT variable_definition_list IS statement_list END PROCEDURE'''
    # Procedure with a RETURN clause: the return type is the raw IDENT string.
    procedure = ASTnode('procedure_definition', p.lineno(2), p[2])
    procedure.children_formals = p[4].children_formals
    procedure.children_variable_definitions = p[8].children_definitions
    procedure.children_statements = p[10].children_statements
    procedure.child_return_type = p[7]
    p[0] = procedure
def p_formal_list1(p):
    'formal_list : empty'
    # No formal parameters at all.
    wrapper = ASTnode('formal_list', p.lineno(1))
    wrapper.children_formals = []
    p[0] = wrapper


def p_formal_list2(p):
    'formal_list : formal_arg'
    # Exactly one formal parameter starts the list.
    wrapper = ASTnode('formal_list', p.lineno(1))
    wrapper.children_formals = [p[1]]
    p[0] = wrapper


def p_formal_list3(p):
    'formal_list : formal_list COMMA formal_arg'
    # Further parameters are appended in place to the existing wrapper.
    collected = p[1]
    collected.children_formals.append(p[3])
    p[0] = collected
def p_formal_arg(p):
    'formal_arg : IDENT LSQUARE IDENT RSQUARE'
    # name[type]: the node value is the parameter name and the declared type
    # string is stored directly in the node's `type` attribute.
    formal = ASTnode('formal_argument', p.lineno(1), p[1])
    formal.type = p[3]
    p[0] = formal
def p_procedure_call1(p):
    'procedure_call : PROC_IDENT LPAREN RPAREN'
    # Call without arguments.
    call = ASTnode('procedure_call', p.lineno(1), p[1])
    call.children_arguments = []
    p[0] = call


def p_procedure_call(p):
    '''procedure_call : PROC_IDENT LPAREN arguments RPAREN'''
    # Call with arguments: the list is taken from the arguments wrapper.
    call = ASTnode('procedure_call', p.lineno(1), p[1])
    call.children_arguments = p[3].children_arguments
    p[0] = call
def p_arguments1(p):
    'arguments : expression'
    # First actual argument: start a new wrapper holding a one-element list.
    wrapper = ASTnode('arguments', p.lineno(1))
    wrapper.children_arguments = [p[1]]
    p[0] = wrapper


def p_arguments2(p):
    'arguments : arguments COMMA expression'
    # Further arguments are appended in place to the existing wrapper.
    collected = p[1]
    collected.children_arguments.append(p[3])
    p[0] = collected
def p_assignment(p):
    'assignment : lvalue EQ rvalue'
    # The line number comes from the EQ token (symbol 2).
    assignment = ASTnode('assignment', p.lineno(2))
    assignment.child_lhs, assignment.child_rhs = p[1], p[3]
    p[0] = assignment
def p_lvalue1(p):
    'lvalue : IDENT'
    # Plain variable as assignment target.
    target = ASTnode('identifier', p.lineno(1), p[1])
    p[0] = target


def p_lvalue2(p):
    'lvalue : IDENT DOT IDENT'
    # variable.attribute as assignment target; both names become identifier
    # nodes carrying the line of their own token.
    target = ASTnode('attribute_write', p.lineno(1))
    target.child_identifier = ASTnode('identifier', p.lineno(1), p[1])
    target.child_attribute = ASTnode('identifier', p.lineno(3), p[3])
    p[0] = target
def p_rvalue(p):
    '''rvalue : expression
              | unless_expression'''
    # Pure pass-through of whichever alternative matched.
    matched = p[1]
    p[0] = matched
def p_print_statement1(p):
    'print_statement : PRINT print_item'
    # First printed item: create the print node with a one-element list.
    statement = ASTnode('print', p.lineno(1))
    statement.children_items = [p[2]]
    p[0] = statement


def p_print_statement2(p):
    'print_statement : print_statement AMPERSAND print_item'
    # Items chained with '&' are appended in place to the same print node.
    statement = p[1]
    statement.children_items.append(p[3])
    p[0] = statement
def p_print_item1(p):
    'print_item : STRING'
    # String literals can only appear here, as print items.
    literal = ASTnode('string_literal', p.lineno(1), p[1])
    p[0] = literal


def p_print_item2(p):
    'print_item : expression'
    # Any expression can be printed; its node is passed through unchanged.
    matched = p[1]
    p[0] = matched
def p_statement1(p):
    '''statement : procedure_call
                 | assignment
                 | print_statement'''
    # Pure pass-through of whichever simple statement matched.
    matched = p[1]
    p[0] = matched
def p_statement2(p):
    'statement : DO statement_list UNTIL expression'
    # Loop node: body statements plus the terminating condition.
    loop = ASTnode('do_until', p.lineno(1))
    loop.children_statements = p[2].children_statements
    loop.child_condition = p[4]
    p[0] = loop
def p_statement3(p):
    'statement : DO statement_list UNLESS expression DONE'
    # Conditional without OTHERWISE: the DO body is stored as the 'false'
    # branch and the 'true' branch is an empty list.
    branch = ASTnode('do_unless', p.lineno(1))
    branch.children_statements_false = p[2].children_statements
    branch.child_condition = p[4]
    branch.children_statements_true = []
    p[0] = branch


def p_statement4(p):
    'statement : DO statement_list UNLESS expression OTHERWISE statement_list DONE'
    # Conditional with OTHERWISE: the DO body is the 'false' branch and the
    # OTHERWISE body is the 'true' branch.
    branch = ASTnode('do_unless', p.lineno(1))
    branch.children_statements_false = p[2].children_statements
    branch.child_condition = p[4]
    branch.children_statements_true = p[6].children_statements
    p[0] = branch
def p_statement5(p):
    'statement : RETURN expression'
    # Return statement carrying the expression whose value is returned.
    returning = ASTnode('return', p.lineno(1))
    returning.child_expression = p[2]
    p[0] = returning
def p_expression1(p):
    'expression : simple_expr'
    # No comparison operator: pass the simple expression through.
    matched = p[1]
    p[0] = matched


def p_expression2(p):
    '''expression : expression EQ simple_expr
                  | expression LT simple_expr'''
    # Comparison: the operator text is the node value and the operator
    # token supplies the line number.
    operation = ASTnode('binary_op', p.lineno(2), p[2])
    operation.child_lhs, operation.child_rhs = p[1], p[3]
    p[0] = operation
def p_simple_expr1(p):
    'simple_expr : term'
    # No additive operator: pass the term through.
    matched = p[1]
    p[0] = matched


def p_simple_expr2(p):
    '''simple_expr : simple_expr PLUS term
                   | simple_expr MINUS term'''
    # Addition / subtraction, built left-associatively by the grammar.
    operation = ASTnode('binary_op', p.lineno(2), p[2])
    operation.child_lhs, operation.child_rhs = p[1], p[3]
    p[0] = operation
def p_term1(p):
    'term : factor'
    # No multiplicative operator: pass the factor through.
    matched = p[1]
    p[0] = matched


def p_term2(p):
    '''term : term MULT factor
            | term DIV factor'''
    # Multiplication / division, built left-associatively by the grammar.
    operation = ASTnode('binary_op', p.lineno(2), p[2])
    operation.child_lhs, operation.child_rhs = p[1], p[3]
    p[0] = operation
def p_factor1(p):
    'factor : atom'
    # No sign prefix: pass the atom through.
    matched = p[1]
    p[0] = matched


def p_factor2(p):
    '''factor : MINUS atom
              | PLUS atom'''
    # Signed atom: the sign character is the node value.
    signed = ASTnode('unary_op', p.lineno(1), p[1])
    signed.child_atom = p[2]
    p[0] = signed
def p_atom1(p):
    'atom : IDENT'
    # Variable reference.
    leaf = ASTnode('identifier', p.lineno(1), p[1])
    p[0] = leaf


def p_atom2(p):
    'atom : INT_LITERAL'
    # Integer literal; the value is whatever the lexer stored in the token.
    leaf = ASTnode('int_literal', p.lineno(1), p[1])
    p[0] = leaf


def p_atom3(p):
    'atom : DATE_LITERAL'
    # Date literal; the value is whatever the lexer stored in the token.
    leaf = ASTnode('date_literal', p.lineno(1), p[1])
    p[0] = leaf
def p_atom4(p):
    'atom : IDENT APOSTROPHE IDENT'
    # variable'attribute read access; both names become identifier nodes
    # carrying the line of their own token.
    access = ASTnode('attribute_read', p.lineno(1))
    access.child_identifier = ASTnode('identifier', p.lineno(1), p[1])
    access.child_attribute = ASTnode('identifier', p.lineno(3), p[3])
    p[0] = access
def p_atom5(p):
    'atom : LPAREN expression RPAREN'
    # Parentheses only group; the inner expression node is used as-is.
    inner = p[2]
    p[0] = inner


def p_atom6(p):
    '''atom : function_call
            | procedure_call'''
    # Calls are usable as atoms; pass the call node through.
    matched = p[1]
    p[0] = matched
def p_function_call1(p):
    'function_call : FUNC_IDENT LPAREN RPAREN'
    # Call without arguments.
    call = ASTnode('function_call', p.lineno(1), p[1])
    call.children_arguments = []
    p[0] = call


def p_function_call2(p):
    'function_call : FUNC_IDENT LPAREN arguments RPAREN'
    # Call with arguments: the list is taken from the arguments wrapper.
    call = ASTnode('function_call', p.lineno(1), p[1])
    call.children_arguments = p[3].children_arguments
    p[0] = call
def p_unless_expression(p):
    'unless_expression : DO expression UNLESS expression OTHERWISE expression DONE'
    # Conditional expression with three parts: condition, DO value, OTHERWISE value.
    # NOTE(review): the DO expression is stored as 'true' and the OTHERWISE
    # expression as 'false', while p_statement3/p_statement4 store the DO body
    # as the 'false' branch -- confirm which convention execution should follow.
    conditional = ASTnode('unless_expression', p.lineno(1))
    conditional.child_condition = p[4]
    conditional.child_expression_true = p[2]
    conditional.child_expression_false = p[6]
    p[0] = conditional
def p_error(p):
    # Syntax error handler for the yacc parser.
    #   p -- the offending token, or None when the input ended unexpectedly.
    # Prints a one-line diagnostic and terminates the program.
    if p is None:
        print('Syntax Error at the end of file')
    else:
        print(f"{{{p.lexer.lineno}}}:Syntax Error (token:'{p.value}')")
    # A bare `raise SystemExit` exits with status 0, which makes a failed
    # parse look successful to the calling shell; exit with 1 instead, the
    # same status semantic errors use.
    raise SystemExit(1)
def syntax_check_file(file_path: str, debug: bool) -> ASTnode:
    """Parse the file at `file_path` and return the root of its AST.

    file_path -- path of the UTF-8 source file to parse
    debug     -- forwarded to the PLY parser's `debug` option

    Syntax errors do not return: p_error prints a message and raises
    SystemExit.
    """
    parser = yacc.yacc()
    with open(file_path, 'r', encoding='utf-8') as file:
        source = file.read()
    # The lexer is one shared module-level object and feeding it new input
    # does not reset its line counter, so start every parse from line 1.
    lexer.lexer.lineno = 1
    # tracking=True makes p.lineno(n) meaningful for nonterminals too;
    # without it, nodes built from e.g. `p.lineno(1)` of a statement_list
    # are stamped with line 0.
    return parser.parse(source, lexer=lexer.lexer, debug=debug, tracking=True)

View File

@ -0,0 +1,132 @@
#!/bin/env python3
import argparse
import datetime
import ply.lex as lex
# Reserved words of the language. Every keyword maps to the token type of
# the same name in upper case; t_IDENT consults this table.
reserved = {
    word: word.upper()
    for word in (
        'var',
        'is',
        'unless',
        'otherwise',
        'until',
        'do',
        'done',
        'procedure',
        'function',
        'return',
        'print',
        'end',
    )
}

# Complete list of token type names: punctuation and operators, literals,
# the three identifier classes, then one token type per reserved word.
tokens = [
    'LPAREN', 'RPAREN',
    'LSQUARE', 'RSQUARE',
    'LCURLY', 'RCURLY',
    'APOSTROPHE', 'AMPERSAND', 'COMMA', 'DOT',
    'EQ', 'LT',
    'PLUS', 'MINUS', 'MULT', 'DIV',
    'STRING', 'DATE_LITERAL', 'INT_LITERAL',
    'IDENT', 'FUNC_IDENT', 'PROC_IDENT',
    *reserved.values(),
]
def t_whitespace(t):
    r'[ \t\n]+'
    # Whitespace produces no token (nothing is returned); only the line
    # counter is advanced by the number of newlines consumed.
    newlines = t.value.count('\n')
    t.lexer.lineno = t.lexer.lineno + newlines


def t_comment(t):
    r'\(%(.|\n)*?%\)'
    # Comments (% ... %) may span lines; they produce no token but their
    # newlines still have to be counted.
    newlines = t.value.count('\n')
    t.lexer.lineno = t.lexer.lineno + newlines
# Tokens that need no action code: each name from `tokens` is bound to the
# regular expression that matches it.

# Brackets
t_LPAREN = r'\('
t_RPAREN = r'\)'
t_LSQUARE = r'\['
t_RSQUARE = r'\]'
t_LCURLY = r'\{'
t_RCURLY = r'\}'
# Punctuation
t_APOSTROPHE = r'\''
t_AMPERSAND = r'&'
t_COMMA = r','
t_DOT = r'\.'
# Comparison operators
t_EQ = r'='
t_LT = r'<'
# Arithmetic operators
t_PLUS = r'\+'
t_MINUS = r'-'
t_MULT = r'\*'
t_DIV = r'/'
def t_STRING(t):
    r'".*?"'
    # The token value becomes the text between the quotes; the first and
    # last character of the match are dropped.
    quoted = t.value
    t.value = quoted[1:-1]
    return t
def t_DATE_LITERAL(t):
    r'\d{4}-\d{2}-\d{2}'
    # Converts the matched YYYY-MM-DD text into a datetime.date stored as
    # the token value. Text that has the right shape but is not a real
    # calendar date (e.g. 2024-02-30) is reported and ends the program.
    try:
        t.value = datetime.date.fromisoformat(t.value)
    except ValueError:
        # Only the conversion failure is handled here; a bare `except:`
        # would also swallow KeyboardInterrupt and genuine programming errors.
        print(f'Invalid date \'{t.value}\' at line {t.lexer.lineno}')
        # Exit with a failing status; a bare SystemExit would report success.
        raise SystemExit(1)
    return t
def t_INT_LITERAL(t):
    r'-?\d{1,3}(\'\d{3})*'
    # Integers may carry an optional leading minus and use an apostrophe as
    # thousands separator (1'000'000); the separators are removed before
    # converting the text to int.
    digit_groups = t.value.split('\'')
    t.value = int(''.join(digit_groups))
    return t
def t_IDENT(t):
    r'[a-z][a-zA-Z0-9_]+'
    # Lower-case initial: either a reserved word (then the token type is the
    # keyword's own type) or an ordinary identifier.
    word = t.value
    t.type = reserved[word] if word in reserved else 'IDENT'
    return t


def t_FUNC_IDENT(t):
    r'[A-Z][a-z0-9_]+'
    # Function names: one upper-case letter followed by lower-case letters,
    # digits or underscores. The token is passed on unchanged.
    return t


def t_PROC_IDENT(t):
    r'[A-Z]{2}[A-Z0-9_]*'
    # Procedure names: at least two upper-case letters, then upper-case
    # letters, digits or underscores. The token is passed on unchanged.
    return t
def t_error(t):
    # Called by the lexer when no token rule matches at the current position.
    #   t -- token whose value holds the remaining input; only its first
    #        character is reported.
    # Prints a diagnostic and terminates the program.
    print(f'Illegal character \'{t.value[0]}\' at line {t.lexer.lineno}')
    # A bare `raise SystemExit` exits with status 0, which hides the failure
    # from the calling shell; use 1, the status semantic errors also use.
    raise SystemExit(1)
# Build the module-level lexer from the token rules of this module. The
# parser module passes this same object to its parse call as `lexer.lexer`.
lexer = lex.lex()
def tokenize_file(file_path: str):
    """Tokenize the UTF-8 file at `file_path` and print every token, one per line."""
    with open(file_path, 'r', encoding='utf-8') as source:
        lexer.input(source.read())
    # lexer.token() yields None once the input is exhausted.
    while current := lexer.token():
        print(current)
if __name__ == '__main__':
    # Command line: exactly one of --who (author info) or -f FILE (dump tokens).
    cli = argparse.ArgumentParser()
    mode = cli.add_mutually_exclusive_group(required=True)
    mode.add_argument('--who', action='store_true', help='print out student IDs and NAMEs of authors')
    mode.add_argument('-f', '--file', help='filename to process')
    options = cli.parse_args()
    if not options.who:
        tokenize_file(options.file)
    else:
        for line in ('Author', ' Student ID: 150189237', ' Name: Oskari Alaranta'):
            print(line)

View File

@ -0,0 +1,300 @@
#!/bin/env python3
import argparse
from copy import deepcopy
from datetime import date, timedelta
import tree_print
from build_ast import ASTnode, syntax_check_file
class SemData:
    """Mutable state shared by all steps of the semantic check.

    scope               -- definition node currently being checked, None at global level
    root                -- first node handed to semantic_check
    global_symbol_table -- name -> node for global variables, functions and procedures
    local_symbol_table  -- name -> node for formals and locals of the current scope
    """

    def __init__(self):
        self.scope = self.root = None
        self.global_symbol_table = dict()
        self.local_symbol_table = dict()
def semantic_error(msg: str, node: ASTnode) -> None:
    """Report a semantic error at `node`'s line and exit with status 1.

    The message is wrapped in ANSI colour escape sequences. This function
    never returns normally.
    """
    report = f'\033[31mSemantic Error: {msg} at line {node.lineno}\033[m'
    print(report)
    raise SystemExit(1)
def print_todo(msg: str, node: ASTnode) -> None:
    """Report a not-yet-implemented feature at `node`'s line and exit with status 2.

    The message is wrapped in ANSI colour escape sequences. This function
    never returns normally.
    """
    report = f'\033[33mTODO: {msg} at line {node.lineno}\033[m'
    print(report)
    raise SystemExit(2)
def semantic_check(node: ASTnode, sem_data: SemData) -> None | ASTnode:
if sem_data.root is None:
sem_data.root = node
match node.nodetype:
case 'program':
# Collect function and procedure definitions first,
# since they can be called before they are defined
for child in node.children_definitions:
if child.nodetype in ['function_definition', 'procedure_definition']:
if child.value in sem_data.global_symbol_table:
semantic_error(f'Redefinition of {child.nodetype.split("_")[0]} \'{child.value}\'', child)
sem_data.global_symbol_table[child.value] = child
# Then do the actual semantic checking
for child in node.children_definitions:
semantic_check(child, sem_data)
for child in node.children_statements:
if semantic_check(child, sem_data) is not None:
semantic_error(f'Expression return value is not handled', child)
return None
case 'variable_definition':
# Check if variable is already defined
symbol_table = sem_data.global_symbol_table
if sem_data.scope is not None:
symbol_table = sem_data.local_symbol_table
if node.value in symbol_table:
semantic_error(f'Redefinition of variable \'{node.value}\'', node)
# Check if expression is valid and store it in symbol table
variable = semantic_check(node.child_expression, sem_data)
if variable is None or variable.type not in ['int', 'string', 'date']:
semantic_error(f'Invalid variable type \'{variable.type if variable is not None else None}\'', node)
symbol_table[node.value] = variable
return None
case 'function_definition' | 'procedure_definition':
# Function and procedures are added to global symbol table
# as the first step, so they can be called before they are defined
assert node.value in sem_data.global_symbol_table
# Local symbols table should be empty while doing checking,
# since functions and procedures can only be defined in global scope
assert len(sem_data.local_symbol_table) == 0 and sem_data.scope is None
sem_data.scope = node
# Collect local arguments
for formal in node.children_formals:
if formal.value in sem_data.local_symbol_table:
semantic_error(f'Redefinition of variable \'{formal.value}\' in {node.nodetype.split("_")[0]} \'{node.value}\' arguments', node)
sem_data.local_symbol_table[formal.value] = formal
# Collect local variables
for variable_definition in node.children_variable_definitions:
semantic_check(variable_definition, sem_data)
# Check return type
if node.nodetype == 'function_definition':
expression = semantic_check(node.child_expression, sem_data)
if expression is None:
semantic_error(f'Function \'{node.value}\' must return a value', node)
if node.child_return_type == 'auto':
node.child_return_type = expression.type
if expression.type != node.child_return_type:
semantic_error(f'Function \'{node.value}\' return type is {node.child_return_type} but returns {expression.type}', node)
elif node.nodetype == 'procedure_definition':
returns = None
for statement in node.children_statements:
returns = None
value = semantic_check(statement, sem_data)
if value is None:
continue
if value.nodetype != 'return':
semantic_error(f'Expression return value is not handled', statement)
if node.child_return_type is None:
semantic_error(f'Procedure \'{node.value}\' does not have a return type', node)
if node.child_return_type == 'auto':
node.child_return_type = value.type
if value.type != node.child_return_type:
semantic_error(f'Procedure \'{node.value}\' return type is {node.child_return_type} but returns {value.type}', node)
returns = value.type
if returns is None and node.child_return_type is not None:
if node.child_return_type != 'void':
semantic_error(f'Procedure \'{node.value}\' must return a value when scope exits', node)
else:
assert False
node.type = node.child_return_type
sem_data.scope = None
sem_data.local_symbol_table = {}
return None
case 'return':
if sem_data.scope is None or sem_data.scope.nodetype != 'procedure_definition':
semantic_error(f'Keyword \'return\' can only appear in procefure_definition')
result = semantic_check(node.child_expression, sem_data)
if result is None:
semantic_error(f'Procedure \'{sem_data.scope.value}\' must return a value', node)
node.type = result.type
return node
case 'date_literal' | 'int_literal' | 'string_literal':
node.type = node.nodetype.split('_')[0]
return node
case 'assignment':
lhs = semantic_check(node.child_lhs, sem_data)
rhs = semantic_check(node.child_rhs, sem_data)
if lhs is None or rhs is None or lhs.type != rhs.type:
semantic_error(f'Invalid assignment of \'{rhs.type if rhs is not None else None}\' to \'{lhs.type if lhs is not None else None}\'', node)
return None
case 'binary_op':
lhs = semantic_check(node.child_lhs, sem_data)
rhs = semantic_check(node.child_rhs, sem_data)
if lhs is None or rhs is None:
semantic_error(f'Invalid operands \'{lhs.type if lhs is not None else None}\' and \'{rhs.type if rhs is not None else None}\' for binary operation {node.value}', node)
# Validate operands and result type
if node.value in ['*', '/']:
if lhs.type == 'int' and rhs.type == 'int':
node.type = 'int'
return node
elif node.value == '+':
if lhs.type == 'date' and rhs.type == 'int':
node.type = 'date'
return node
if lhs.type == 'int' and rhs.type == 'int':
node.type = 'int'
return node
elif node.value == '-':
if lhs.type == 'date' and rhs.type == 'int':
node.type = 'date'
return node
if lhs.type == 'date' and rhs.type == 'date':
node.type = 'int'
return node
if lhs.type == 'int' and rhs.type == 'int':
node.type = 'int'
return node
elif node.value in ['<', '>', '=']:
if lhs.type == rhs.type:
node.type = 'bool'
return node
semantic_error(f'Invalid operands \'{lhs.type}\' and \'{rhs.type}\' for operation {node.value}', node)
case 'identifier':
# Check if variable is defined
symbol = None
if node.value in sem_data.local_symbol_table:
symbol = sem_data.local_symbol_table[node.value]
if node.value in sem_data.global_symbol_table:
symbol = sem_data.global_symbol_table[node.value]
if symbol is not None:
node.type = symbol.type
return symbol
semantic_error(f'Variable \'{node.value}\' not defined', node)
case 'function_call' | 'procedure_call':
# Handle built in functions
if node.nodetype == 'function_call' and node.value == 'Today':
if len(node.children_arguments) != 0:
semantic_error(f'Builtin function \'Today\' takes no arguments', node)
node.type = 'date'
return node
# Check if function/procedure is defined
if node.value not in sem_data.global_symbol_table:
semantic_error(f'{node.nodetype.split("_")[0]} \'{node.value}\' not defined', node)
func = sem_data.global_symbol_table[node.value]
# Check if arguments match (count and types)
if len(node.children_arguments) != len(func.children_formals):
semantic_error(f'Argument count mismatch for {node.nodetype.split("_")[0]} \'{node.value}\', expected {len(func.children_formals)} but got {len(node.children_arguments)}', node)
for formal, actual in zip(func.children_formals, node.children_arguments):
resolved = semantic_check(actual, sem_data)
if resolved is None or formal.type != resolved.type:
semantic_error(f'Argument type mismatch for {node.nodetype.split("_")[0]} \'{node.value}\', expected \'{formal.type}\' but got \'{resolved.type if resolved is not None else None}\'', node)
# Set return type and return node if func has a return type
node.type = func.child_return_type
return node if node.type is not None else None
case 'do_unless':
# Validate condition
condition = semantic_check(node.child_condition, sem_data)
if condition is None or condition.type != 'bool':
semantic_error('Condition must be of type \'bool\'', node)
# Validate both branches
for statement in node.children_statements_true:
if semantic_check(statement, sem_data) is not None:
semantic_error(f'Expression return value is not handled', statement)
for statement in node.children_statements_false:
if semantic_check(statement, sem_data) is not None:
semantic_error(f'Expression return value is not handled', statement)
return None
case 'do_until':
# Validate condition
condition = semantic_check(node.child_condition, sem_data)
if condition is None or condition.type != 'bool':
semantic_error('Condition must be of type bool', node)
# Validate body
for statement in node.children_statements:
if semantic_check(statement, sem_data) is not None:
semantic_error(f'Expression return value is not handled', statement)
return None
case 'unless_expression':
# Validate condition
condition = semantic_check(node.child_condition, sem_data)
if condition is None or condition.type != 'bool':
semantic_error('Condition must be of type bool', node)
# Validate both branches
expression_true = semantic_check(node.child_expression_true, sem_data)
expression_false = semantic_check(node.child_expression_false, sem_data)
if expression_true is None or expression_false is None or expression_true.type != expression_false.type:
semantic_error(f'Branches must return the same type, got \'{expression_false.type}\' and \'{expression_true.type}\'', node)
node.type = expression_true.type
return node
case 'attribute_read' | 'attribute_write':
# Check if variable is defined
symbol = None
if node.child_identifier.value in sem_data.local_symbol_table:
symbol = sem_data.local_symbol_table[node.child_identifier.value]
elif node.child_identifier.value in sem_data.global_symbol_table:
symbol = sem_data.global_symbol_table[node.child_identifier.value]
else:
semantic_error(f'Variable \'{node.child_identifier.value}\' not defined', node.child_identifier)
# Validate attribute
assert node.child_attribute.nodetype == 'identifier'
if symbol.type != 'date':
semantic_error(f'Cannot access attribute of non-date variable', node.child_attribute)
valid_attributes = ['day', 'month', 'year']
if node.nodetype == 'attribute_read':
valid_attributes += ['weekday', 'weeknum']
if node.child_attribute.value not in valid_attributes:
semantic_error(f'Invalid attribute \'{node.child_attribute.value}\' for {node.nodetype.split("_")[0]}, allowed values {valid_attributes}', node.child_attribute)
node.type = 'date'
return node
case 'print':
for item in node.children_items:
value = semantic_check(item, sem_data)
if value is None or value.type not in ['int', 'string', 'date']:
semantic_error('Print argument can only be \'int\', \'date\' or \'string\'', node)
return None
case _:
print_todo(f'Semantic check type \'{node.nodetype}\'', node)
def execute_ast(node: ASTnode, sem_data: SemData) -> None | ASTnode:
match node.nodetype:
case _:
print_todo(f'Execute type \'{node.nodetype}\'', node)
if __name__ == '__main__':
    # Command line: optional -d, plus exactly one of --who or -f FILE.
    cli = argparse.ArgumentParser()
    cli.add_argument('-d', '--debug', action='store_true', help='debug?')
    mode = cli.add_mutually_exclusive_group(required=True)
    mode.add_argument('--who', action='store_true', help='print out student IDs and NAMEs of authors')
    mode.add_argument('-f', '--file', help='filename to process')
    options = cli.parse_args()
    if not options.who:
        # Pipeline: parse -> semantic check -> execute.
        tree = syntax_check_file(options.file, options.debug)
        # To inspect the parsed tree: tree_print.treeprint(tree, 'unicode')
        state = SemData()
        semantic_check(tree, state)
        execute_ast(tree, state)
    else:
        for line in ('Author', ' Student ID: 150189237', ' Name: Oskari Alaranta'):
            print(line)

View File

@ -0,0 +1,209 @@
#!/usr/bin/env python3
# ----------------------------------------------------------------------
# Values to control the module's working
# How to recognize attributes in nodes by their names
# Attribute-name prefix marking an attribute that holds ONE child node
child_prefix_default = "child_"
# Attribute-name prefix marking an attribute that holds a LIST of child nodes
children_prefix_default = "children_"
# Names of the node attributes the printers look for (each one is optional
# on a node; the printers test for it with hasattr)
value_attr = "value"
nodetype_attr = "nodetype"
lineno_attr = "lineno"
type_attr = "type"
# Finding and creating a list of all children nodes of a node, based on
# attribute names of a node
def get_childvars(node, child_prefix=child_prefix_default,
                  children_prefix=children_prefix_default):
    '''Collect the children of a tree node as (label, child) pairs.

    Children are discovered purely by attribute name, in attribute
    definition order:
      * an attribute starting with child_prefix holds one child; its label
        is the attribute name without the prefix
      * an attribute starting with children_prefix holds an iterable of
        children; each gets the label "name[index]"
    Problem cases of a children attribute yield a single (label, None) pair
    whose label says what is wrong: None stored, a non-iterable stored, or
    an empty iterable ("name[EMPTY]").
    Objects without a __dict__ have no children; an empty list is returned.'''
    pairs = []
    if not hasattr(node, "__dict__"):
        return pairs
    for attr_name, attr_val in vars(node).items():
        if attr_name.startswith(child_prefix):
            # Single child
            pairs.append((attr_name[len(child_prefix):], attr_val))
            continue
        if not attr_name.startswith(children_prefix):
            # Ordinary attribute, not part of the tree structure
            continue
        label = attr_name[len(children_prefix):]
        if attr_val is None:
            pairs.append((label+"[NONE stored instead of a list!!!]", None))
        elif not hasattr(attr_val, "__iter__"):
            pairs.append((label+"[Not a list!!!]", None))
        elif not attr_val:
            pairs.append((label+"[EMPTY]", None))
        else:
            for index, child in enumerate(attr_val):
                pairs.append((label+"["+str(index)+"]", child))
    return pairs
# Printing the syntax tree (AST)
# Strings that ASCII and Unicode trees are made out of
# Unicode variant: box drawing characters
vertical_uni = "\N{BOX DRAWINGS LIGHT VERTICAL}"
horizontal_uni = "\N{BOX DRAWINGS LIGHT HORIZONTAL}"
vertical_right_uni = "\N{BOX DRAWINGS LIGHT VERTICAL AND RIGHT}"
up_right_uni = "\N{BOX DRAWINGS LIGHT UP AND RIGHT}"
# Prefix of the first line of a child that has siblings after it
child_indent_uni = vertical_right_uni + horizontal_uni + horizontal_uni
# Prefix of the first line of the last child
last_child_indent_uni = up_right_uni + horizontal_uni + horizontal_uni
# Prefixes for the remaining lines of a child's subtree (non-last / last child)
normal_indent_uni = vertical_uni + " "
last_normal_indent_uni = " "
# ASCII variant: same roles as above, built from |, - and +
vertical_asc = "|"
horizontal_asc = "-"
vertical_right_asc = "+"
up_right_asc = "+"
child_indent_asc = vertical_right_asc + horizontal_asc + horizontal_asc
last_child_indent_asc = up_right_asc + horizontal_asc + horizontal_asc
normal_indent_asc = vertical_asc + " "
last_normal_indent_asc = " "
# What to put to the beginning and end of dot files
# (treeprint prints the preamble before and the postamble after the tree)
dot_preamble='''digraph parsetree {
ratio=fill
node [shape="box"]
edge [style=bold]
ranksep=equally
nodesep=0.5
rankdir = TB
clusterrank = local'''
dot_postamble='}'
def dotnodeid(nodenum):
    '''Return the dot identifier of a node number: "N" followed by the number.'''
    return f"N{nodenum!s}"
def treeprint_indent(node, outtype="unicode", label="", first_indent="", indent=""):
    '''Print a subtree as indented text, one node per line.

    node         = root of the subtree (a falsy node is printed as NONE)
    outtype      = "unicode" selects box drawing characters, anything else ASCII
    label        = role of this subtree in its parent (from the attribute name)
    first_indent = prefix of the node's own line
    indent       = prefix that the lines of the node's children are built on'''
    if label:
        first_indent += label + ": "
    if not node:
        print(first_indent + "NONE")
        return
    # Node line: node type (or a marked dump of the object when it has no
    # node type attribute), then the optional value, type and line number.
    if hasattr(node, nodetype_attr):
        headline = getattr(node, nodetype_attr)
    else:
        headline = "??? '" + str(node) + "' ???"
    print(first_indent + headline, end="")
    decorations = (
        (value_attr, " (", ")"),
        (type_attr, " :", ""),
        (lineno_attr, " #", ""),
    )
    for attr, opener, closer in decorations:
        if hasattr(node, attr):
            print(opener + str(getattr(node, attr)) + closer, end="")
    print()
    # Children: every child but the last gets the "more siblings follow"
    # connectors, the last one gets the closing connectors.
    use_unicode = outtype == "unicode"
    children = get_childvars(node)
    last_position = len(children) - 1
    for position, (name, value) in enumerate(children):
        if position < last_position:
            if use_unicode:
                branch, trunk = child_indent_uni, normal_indent_uni
            else:
                branch, trunk = child_indent_asc, normal_indent_asc
        else:
            if use_unicode:
                branch, trunk = last_child_indent_uni, last_normal_indent_uni
            else:
                branch, trunk = last_child_indent_asc, last_normal_indent_asc
        treeprint_indent(value, outtype, name, indent + branch, indent + trunk)
def treeprint_dot(node, nodenum, nodecount):
    '''Print a subtree as dot statements.

    node      = root of the subtree (a falsy node becomes an ellipse labelled NONE)
    nodenum   = number of this node, used to build its dot id
    nodecount = one-element list holding the highest node number used so far;
                it is updated in place as children are numbered'''
    this_id = dotnodeid(nodenum)
    if not node:
        print(this_id + ' [shape="ellipse", label="NONE"]')
        return
    # First label line: node type, or a marked dump of the object when it
    # has no node type attribute.
    if hasattr(node, nodetype_attr):
        title = getattr(node, nodetype_attr)
    else:
        title = "??? '" + str(node) + "' ???"
    # Optional second label line: value, type and line number, when present.
    details = "".join(
        opener + str(getattr(node, attr)) + closer
        for attr, opener, closer in (
            (value_attr, " (", ")"),
            (type_attr, " :", ""),
            (lineno_attr, " #", ""),
        )
        if hasattr(node, attr)
    )
    declaration = this_id + ' [label="' + title
    if details:
        declaration += "\n" + details
    print(declaration + '"]')
    for name, value in get_childvars(node):
        # Each child takes the next unused node number
        nodecount[0] += 1
        childnum = nodecount[0]
        treeprint_dot(value, childnum, nodecount)
        # Labelled edge from this node to the child
        print(this_id+"->"+dotnodeid(childnum)+ ' [label="'+name+'"]')
def treeprint(rootnode, outtype="unicode"):
    '''Print the tree rooted at rootnode.

    outtype selects the format:
      "dot"     -- dot statements wrapped in the dot preamble and postamble
      "unicode" -- indented text drawn with Unicode box characters (default)
      any other -- indented text drawn with ASCII |, - and +'''
    if outtype != "dot":
        treeprint_indent(rootnode, outtype)
        return
    print(dot_preamble)
    # The root is node 0; the one-element list tracks the highest id used
    treeprint_dot(rootnode, 0, [0])
    print(dot_postamble)