PoPL/04_semantics_and_running/main.py

679 lines
30 KiB
Python
Raw Normal View History

#!/bin/env python3
import argparse
from calendar import timegm
from copy import deepcopy
from datetime import date, timedelta
import subprocess
import tree_print
from build_ast import ASTnode, syntax_check_file
class SemData:
    """Mutable state shared by every `semantic_check` call of one run."""

    def __init__(self):
        # Symbol tables map a variable name to the node that carries its type.
        # The local table is only populated while a definition is checked.
        self.global_symbol_table = {}
        self.local_symbol_table = {}
        # Function and procedure definition nodes keyed by their name.
        self.callables = {}
        # First node ever passed to the checker.
        self.root = None
        # Definition node currently being checked; None means global scope.
        self.scope = None
def semantic_error(msg: str, node: ASTnode) -> None:
    """Report a semantic error for `node` and terminate.

    Prints `msg` together with `node.lineno` in red on standard output and
    then raises `SystemExit(1)`; this function never returns normally.
    """
    red, reset = '\033[31m', '\033[m'
    report = f'{red}Semantic Error: {msg} at line {node.lineno}{reset}'
    print(report)
    raise SystemExit(1)
def print_todo(msg: str, node: ASTnode) -> None:
    """Report a not-yet-implemented feature hit at `node` and terminate.

    Prints `msg` together with `node.lineno` in yellow on standard output and
    then raises `SystemExit(2)`; this function never returns normally.
    """
    yellow, reset = '\033[33m', '\033[m'
    report = f'{yellow}TODO: {msg} at line {node.lineno}{reset}'
    print(report)
    raise SystemExit(2)
def semantic_check(node: ASTnode, sem_data: SemData) -> None | ASTnode:
if sem_data.root is None:
sem_data.root = node
match node.nodetype:
case 'program':
# Collect function and procedure definitions first,
# since they can be called before they are defined
for child in node.children_definitions:
if child.nodetype in ['function_definition', 'procedure_definition']:
if child.value in sem_data.callables:
semantic_error(f'Redefinition of {child.nodetype.split("_")[0]} \'{child.value}\'', child)
sem_data.callables[child.value] = child
# Then do the actual semantic checking
for child in node.children_definitions:
semantic_check(child, sem_data)
for child in node.children_statements:
if semantic_check(child, sem_data) is not None:
semantic_error(f'Expression return value is not handled', child)
return None
case 'variable_definition':
# Check if variable is already defined
symbol_table = sem_data.global_symbol_table
if sem_data.scope is not None:
symbol_table = sem_data.local_symbol_table
if node.value in symbol_table:
semantic_error(f'Redefinition of variable \'{node.value}\'', node)
# Check if expression is valid and store it in symbol table
variable = semantic_check(node.child_expression, sem_data)
if variable is None or variable.type not in ['int', 'string', 'date']:
semantic_error(f'Invalid variable type \'{variable.type if variable is not None else None}\'', node)
symbol_table[node.value] = variable
return None
case 'function_definition' | 'procedure_definition':
# Function and procedures are added to global symbol table
# as the first step, so they can be called before they are defined
assert node.value in sem_data.callables
# Local symbols table should be empty while doing checking,
# since functions and procedures can only be defined in global scope
assert len(sem_data.local_symbol_table) == 0 and sem_data.scope is None
sem_data.scope = node
# Collect local arguments
for formal in node.children_formals:
if formal.value in sem_data.local_symbol_table:
semantic_error(f'Redefinition of variable \'{formal.value}\' in {node.nodetype.split("_")[0]} \'{node.value}\' arguments', node)
sem_data.local_symbol_table[formal.value] = formal
# Collect local variables
for variable_definition in node.children_variable_definitions:
semantic_check(variable_definition, sem_data)
# Check return type
if node.nodetype == 'function_definition':
expression = semantic_check(node.child_expression, sem_data)
if expression is None:
semantic_error(f'Function \'{node.value}\' must return a value', node)
if node.child_return_type == 'auto':
node.child_return_type = expression.type
if expression.type != node.child_return_type:
semantic_error(f'Function \'{node.value}\' return type is {node.child_return_type} but returns {expression.type}', node)
elif node.nodetype == 'procedure_definition':
returns = None
for statement in node.children_statements:
returns = None
value = semantic_check(statement, sem_data)
if value is None:
continue
if value.nodetype != 'return':
semantic_error(f'Expression return value is not handled', statement)
if node.child_return_type is None:
semantic_error(f'Procedure \'{node.value}\' does not have a return type', node)
if node.child_return_type == 'auto':
node.child_return_type = value.type
if value.type != node.child_return_type:
semantic_error(f'Procedure \'{node.value}\' return type is {node.child_return_type} but returns {value.type}', node)
returns = value.type
if returns is None and node.child_return_type is not None:
if node.child_return_type != 'void':
semantic_error(f'Procedure \'{node.value}\' must return a value when scope exits', node)
else:
assert False
node.type = node.child_return_type
sem_data.scope = None
sem_data.local_symbol_table = {}
return None
case 'return':
if sem_data.scope is None or sem_data.scope.nodetype != 'procedure_definition':
semantic_error(f'Keyword \'return\' can only appear in procefure_definition')
result = semantic_check(node.child_expression, sem_data)
if result is None:
semantic_error(f'Procedure \'{sem_data.scope.value}\' must return a value', node)
node.type = result.type
return node
case 'date_literal' | 'int_literal' | 'string_literal':
node.type = node.nodetype.split('_')[0]
return node
case 'assignment':
lhs = semantic_check(node.child_lhs, sem_data)
rhs = semantic_check(node.child_rhs, sem_data)
if lhs is None or rhs is None or lhs.type != rhs.type:
semantic_error(f'Invalid assignment of \'{rhs.type if rhs is not None else None}\' to \'{lhs.type if lhs is not None else None}\'', node)
return None
case 'binary_op':
lhs = semantic_check(node.child_lhs, sem_data)
rhs = semantic_check(node.child_rhs, sem_data)
if lhs is None or rhs is None:
semantic_error(f'Invalid operands \'{lhs.type if lhs is not None else None}\' and \'{rhs.type if rhs is not None else None}\' for binary operation {node.value}', node)
# Validate operands and result type
if node.value in ['*', '/']:
if lhs.type == 'int' and rhs.type == 'int':
node.type = 'int'
return node
elif node.value == '+':
if lhs.type == 'date' and rhs.type == 'int':
node.type = 'date'
return node
if lhs.type == 'int' and rhs.type == 'int':
node.type = 'int'
return node
elif node.value == '-':
if lhs.type == 'date' and rhs.type == 'int':
node.type = 'date'
return node
if lhs.type == 'date' and rhs.type == 'date':
node.type = 'int'
return node
if lhs.type == 'int' and rhs.type == 'int':
node.type = 'int'
return node
elif node.value in ['<', '=']:
if lhs.type == rhs.type:
node.type = 'bool'
return node
semantic_error(f'Invalid operands \'{lhs.type}\' and \'{rhs.type}\' for operation {node.value}', node)
case 'identifier':
# Check if variable is defined
symbol = None
if node.value in sem_data.local_symbol_table:
symbol = sem_data.local_symbol_table[node.value]
if node.value in sem_data.global_symbol_table:
symbol = sem_data.global_symbol_table[node.value]
if symbol is not None:
node.type = symbol.type
return symbol
semantic_error(f'Variable \'{node.value}\' not defined', node)
case 'function_call' | 'procedure_call':
# Handle built in functions
if node.nodetype == 'function_call' and node.value == 'Today':
if len(node.children_arguments) != 0:
semantic_error(f'Builtin function \'Today\' takes no arguments', node)
node.type = 'date'
return node
# Check if function/procedure is defined
if node.value not in sem_data.callables:
semantic_error(f'{node.nodetype.split("_")[0]} \'{node.value}\' not defined', node)
func = sem_data.callables[node.value]
# Check if arguments match (count and types)
if len(node.children_arguments) != len(func.children_formals):
semantic_error(f'Argument count mismatch for {node.nodetype.split("_")[0]} \'{node.value}\', expected {len(func.children_formals)} but got {len(node.children_arguments)}', node)
for formal, actual in zip(func.children_formals, node.children_arguments):
resolved = semantic_check(actual, sem_data)
if resolved is None or formal.type != resolved.type:
semantic_error(f'Argument type mismatch for {node.nodetype.split("_")[0]} \'{node.value}\', expected \'{formal.type}\' but got \'{resolved.type if resolved is not None else None}\'', node)
# Set return type and return node if func has a return type
node.type = func.child_return_type
return node if node.type is not None else None
case 'do_unless':
# Validate condition
condition = semantic_check(node.child_condition, sem_data)
if condition is None or condition.type != 'bool':
semantic_error('Condition must be of type \'bool\'', node)
# Validate both branches
for statement in node.children_statements_true:
if semantic_check(statement, sem_data) is not None:
semantic_error(f'Expression return value is not handled', statement)
for statement in node.children_statements_false:
if semantic_check(statement, sem_data) is not None:
semantic_error(f'Expression return value is not handled', statement)
return None
case 'do_until':
# Validate condition
condition = semantic_check(node.child_condition, sem_data)
if condition is None or condition.type != 'bool':
semantic_error('Condition must be of type bool', node)
# Validate body
for statement in node.children_statements:
if semantic_check(statement, sem_data) is not None:
semantic_error(f'Expression return value is not handled', statement)
return None
case 'unless_expression':
# Validate condition
condition = semantic_check(node.child_condition, sem_data)
if condition is None or condition.type != 'bool':
semantic_error('Condition must be of type bool', node)
# Validate both branches
expression_true = semantic_check(node.child_expression_true, sem_data)
expression_false = semantic_check(node.child_expression_false, sem_data)
if expression_true is None or expression_false is None or expression_true.type != expression_false.type:
semantic_error(f'Branches must return the same type, got \'{expression_false.type}\' and \'{expression_true.type}\'', node)
node.type = expression_true.type
return node
case 'attribute_read' | 'attribute_write':
# Check if variable is defined
symbol = None
if node.child_identifier.value in sem_data.local_symbol_table:
symbol = sem_data.local_symbol_table[node.child_identifier.value]
elif node.child_identifier.value in sem_data.global_symbol_table:
symbol = sem_data.global_symbol_table[node.child_identifier.value]
else:
semantic_error(f'Variable \'{node.child_identifier.value}\' not defined', node.child_identifier)
# Validate attribute
assert node.child_attribute.nodetype == 'identifier'
if symbol.type != 'date':
semantic_error(f'Cannot access attribute of non-date variable', node.child_attribute)
valid_attributes = ['day', 'month', 'year']
if node.nodetype == 'attribute_read':
valid_attributes += ['weekday', 'weeknum']
if node.child_attribute.value not in valid_attributes:
semantic_error(f'Invalid attribute \'{node.child_attribute.value}\' for {node.nodetype.split("_")[0]}, allowed values {valid_attributes}', node.child_attribute)
node.type = 'int'
return node
case 'print':
for item in node.children_items:
value = semantic_check(item, sem_data)
if value is None or value.type not in ['int', 'string', 'date']:
semantic_error('Print argument can only be \'int\', \'date\' or \'string\'', node)
return None
case _:
print_todo(f'Semantic check type \'{node.nodetype}\'', node)
class CompileData:
    """Mutable state of the assembly generator plus its runtime builtins.

    Holds the code emitted so far (`code`), finished routines keyed by their
    label (`callables`), the string-literal pool, a label counter and the
    definition currently being compiled (`scope`). The builtin routines are
    registered in `__init__` as ready-made AT&T-syntax assembly text.
    """

    def __init__(self, sem_data: SemData):
        """Create empty generator state bound to the checked program.

        Args:
            sem_data: Result of the semantic check; its
                `global_symbol_table` decides the layout of `.globals`.
        """
        self.sem_data = sem_data
        # Size in bytes of the scratch buffer strftime writes into.
        self.date_buffer_size = 128
        self.string_literals = []
        self.label_counter = 0
        self.callables = {}
        self.scope = None
        self.code = ''
        # Current time rounded down to a multiple of 86400 seconds.
        self.callables['__builtin_today'] = '''\
pushq %rbp
movq %rsp, %rbp
xorq %rdi, %rdi
call time
movq %rax, %rdi
movq $86400, %rcx
xorq %rdx, %rdx
divq %rcx
movq %rdi, %rax
subq %rdx, %rax
popq %rbp
ret
'''
        # Date values are UTC-midnight epoch seconds, so they are broken down
        # with gmtime; localtime would shift them by the host's UTC offset.
        # %eax is cleared before the variadic printf call.
        self.callables['__builtin_print_date'] = f'''\
pushq %rbp
movq %rsp, %rbp
subq $16, %rsp
movq %rdi, 0(%rsp)
leaq 0(%rsp), %rdi
call gmtime
movq $.date_buffer, %rdi
movq ${self.date_buffer_size}, %rsi
movq $.date_format, %rdx
movq %rax, %rcx
call strftime
movq $.str_format, %rdi
movq $.date_buffer, %rsi
xorl %eax, %eax
call printf
leave
ret
'''
        # Formats the date in %rdi with the strftime format in %rsi and
        # converts the resulting text to an integer with atoi.
        self.callables['__builtin_get_day_attr'] = f'''\
pushq %rbp
movq %rsp, %rbp
subq $16, %rsp
movq %rdi, 0(%rsp)
movq %rsi, 8(%rsp)
leaq 0(%rsp), %rdi
call gmtime
movq $.date_buffer, %rdi
movq ${self.date_buffer_size}, %rsi
movq 8(%rsp), %rdx
movq %rax, %rcx
call strftime
movq $.date_buffer, %rdi
call atoi
leave
ret
'''

    def get_label(self) -> str:
        """Return a fresh local label name (`.L0`, `.L1`, ...)."""
        self.label_counter += 1
        return f'.L{self.label_counter - 1}'

    def insert_label(self, label) -> None:
        """Append the definition of `label` to the code emitted so far."""
        self.code += f'{label}:\n'

    def add_string_literal(self, value: str) -> str:
        """Return the data label for `value`, pooling equal strings."""
        for index, string in enumerate(self.string_literals):
            if string == value:
                return f'S{index}'
        self.string_literals.append(value)
        return f'S{len(self.string_literals) - 1}'

    def symbol_address(self, symbol: str) -> str:
        """Return the assembly operand that addresses variable `symbol`.

        Formals of the current scope are looked up first, then its local
        variables, then the globals. An unknown name is an internal error.
        """
        if self.scope is not None:
            for index, formal in enumerate(self.scope.children_formals):
                if formal.value == symbol:
                    # Formals sit above the saved %rbp and the return address
                    offset = 8 * index + 16
                    return f'{offset}(%rbp)'
            for index, variable in enumerate(self.scope.children_variable_definitions):
                if variable.value == symbol:
                    offset = 8 * index + 8
                    return f'-{offset}(%rbp)'
        if symbol in self.sem_data.global_symbol_table:
            # Globals are laid out in symbol-table insertion order
            offset = 8 * list(self.sem_data.global_symbol_table.keys()).index(symbol)
            return f'(.globals + {offset})'
        assert False

    def get_full_code(self) -> str:
        """Return the complete assembly file for everything compiled so far.

        The output consists of the data section (format strings and string
        literals), the bss section (date buffer and globals), every routine
        in `callables`, and `main` wrapping the code in `self.code`.
        """
        global_count = len(self.sem_data.global_symbol_table)
        # Data section with string literals
        prefix = '.section .data\n'
        prefix += '.int_format: .asciz "%lld"\n'
        prefix += '.str_format: .asciz "%s"\n'
        prefix += '.date_format: .asciz "%Y-%m-%d"\n'
        prefix += '.day_format: .asciz "%d"\n'
        prefix += '.month_format: .asciz "%m"\n'
        prefix += '.year_format: .asciz "%Y"\n'
        prefix += '.weekday_format: .asciz "%u"\n'
        prefix += '.weeknum_format: .asciz "%W"\n'
        for index, string in enumerate(self.string_literals):
            prefix += f'S{index}: .asciz "{string}"\n'
        prefix += '\n'
        # BSS section for uninitialized data
        prefix += '.section .bss\n'
        prefix += '.date_buffer:\n'
        prefix += f' .skip {self.date_buffer_size}\n'
        if global_count != 0:
            prefix += '.globals:\n'
            # Use the instance's own semantic data rather than a
            # module-level name that only exists when run as a script
            prefix += f' .skip {global_count * 8}\n'
        prefix += '\n'
        # Text section with code
        prefix += '.section .text\n'
        prefix += '\n'
        # Add function and procedure definitions
        for name, code in self.callables.items():
            prefix += name + ':\n'
            prefix += code
            prefix += '\n'
        prefix += '.global main\n'
        prefix += 'main:\n'
        prefix += ' pushq %rbp\n'
        prefix += ' movq %rsp, %rbp\n'
        # main always returns 0
        postfix = ' xorq %rax, %rax\n'
        postfix += ' leave\n'
        postfix += ' ret\n'
        return prefix + self.code + postfix
def compile_print_literal(print_type: str, compile_data: CompileData) -> None:
    """Emit code that prints the value currently held in %rax.

    Args:
        print_type: Type of the value, one of 'int', 'string' or 'date'.
        compile_data: Generator state whose `code` is appended to.

    Raises:
        AssertionError: if `print_type` is not one of the supported types.
    """
    if print_type in ('int', 'string'):
        format_label = '.int_format' if print_type == 'int' else '.str_format'
        compile_data.code += f' movq ${format_label}, %rdi\n'
        compile_data.code += ' movq %rax, %rsi\n'
        # printf is variadic: %al must hold the number of vector registers
        # used, and it still contains the low byte of the printed value
        compile_data.code += ' xorl %eax, %eax\n'
        compile_data.code += ' call printf\n'
    elif print_type == 'date':
        compile_data.code += ' movq %rax, %rdi\n'
        compile_data.code += ' call __builtin_print_date\n'
    else:
        # Raised explicitly so the check is not stripped under `python -O`
        raise AssertionError(f'unsupported print type {print_type!r}')
def compile_ast(node: ASTnode, compile_data: CompileData) -> None:
match node.nodetype:
case 'program':
# Compile function and procedure definitions
for definition in node.children_definitions:
if definition.nodetype not in ['function_definition', 'procedure_definition']:
continue
assert len(compile_data.code) == 0
assert compile_data.scope is None
compile_data.scope = definition
# initialize stack frame
compile_data.code += ' pushq %rbp\n'
compile_data.code += ' movq %rsp, %rbp\n'
# initialize local variables
stack_size = 8 * len(definition.children_variable_definitions)
if stack_size % 16 != 0:
stack_size += 8
if stack_size != 0:
compile_data.code += f' subq ${stack_size}, %rsp\n'
for variable in definition.children_variable_definitions:
address = compile_data.symbol_address(variable.value)
compile_ast(variable.child_expression, compile_data)
compile_data.code += f' movq %rax, {address}\n'
# compile statements
if definition.nodetype == 'function_definition':
compile_ast(definition.child_expression, compile_data)
elif definition.nodetype == 'procedure_definition':
for statement in definition.children_statements:
compile_ast(statement, compile_data)
else: assert False
# return from procedure
compile_data.code += f' leave\n'
compile_data.code += f' ret\n'
compile_data.callables[definition.value] = compile_data.code
compile_data.code = ''
compile_data.scope = None
# Initialize global variables
for index, (name, variable) in enumerate(compile_data.sem_data.global_symbol_table.items()):
address = compile_data.symbol_address(name)
compile_ast(variable, compile_data)
compile_data.code += f' movq %rax, {address}\n'
# Compile program statements
for statement in node.children_statements:
compile_ast(statement, compile_data)
case 'variable_definition' | 'function_definition' | 'procedure_definition':
assert False
case 'identifier':
address = compile_data.symbol_address(node.value)
compile_data.code += f' movq {address}, %rax\n'
case 'assignment':
if node.child_lhs.nodetype == 'attribute_write':
print_todo('Attribute write', node)
elif node.child_lhs.nodetype == 'identifier':
address = compile_data.symbol_address(node.child_lhs.value)
compile_ast(node.child_rhs, compile_data)
compile_data.code += f' movq %rax, {address}\n'
else: assert False
case 'binary_op':
assert node.value in ['+', '-', '*', '/', '<', '=']
if node.value in ['*', '/']:
assert node.child_lhs.type == 'int'
else:
assert node.child_lhs.type in ['int', 'date']
if node.value == '-' and node.child_lhs.type == 'date':
assert node.child_rhs.type in ['int', 'date']
else:
assert node.child_rhs.type == 'int'
# calculate LHS and store it on stack
compile_data.code += f' subq $16, %rsp\n'
compile_ast(node.child_lhs, compile_data)
compile_data.code += f' movq %rax, 0(%rsp)\n'
# calculate RHS and store it in RCX
compile_ast(node.child_rhs, compile_data)
if node.child_lhs.type == 'date' and node.child_rhs.type == 'int':
# multiply RHS by number of seconds in a day so we can perform arithmetic on dates
compile_data.code += f' imulq $86400, %rax, %rcx\n'
else:
compile_data.code += f' movq %rax, %rcx\n'
# prepare registers, RAX contains LHS and RCX contains RHS
# and restore restore stack
compile_data.code += f' movq 0(%rsp), %rax\n'
compile_data.code += f' addq $16, %rsp\n'
# perform operation
if node.value == '+':
compile_data.code += f' addq %rcx, %rax\n'
elif node.value == '-':
compile_data.code += f' subq %rcx, %rax\n'
elif node.value == '*':
compile_data.code += f' imulq %rcx, %rax\n'
elif node.value == '/':
compile_data.code += f' cqo\n'
compile_data.code += f' idivq %rcx\n'
elif node.value == '<':
compile_data.code += f' cmpq %rcx, %rax\n'
compile_data.code += f' setl %al\n'
compile_data.code += f' movzbq %al, %rax\n'
elif node.value == '=':
compile_data.code += f' cmpq %rcx, %rax\n'
compile_data.code += f' sete %al\n'
compile_data.code += f' movzbq %al, %rax\n'
else: assert False
# if both operands are dates, divide result by number of seconds in a day
if node.child_lhs.type == 'date' and node.child_rhs.type == 'date':
assert node.value == '-'
compile_data.code += f' movq $86400, %rcx\n'
compile_data.code += f' cqo\n'
compile_data.code += f' idivq %rcx\n'
case 'function_call' | 'procedure_call':
if node.value == 'Today':
compile_data.code += f' call __builtin_today\n'
else:
# align stack
stack_needed = len(node.children_arguments) * 8
if stack_needed % 16 != 0:
stack_needed += 8
if stack_needed != 0:
compile_data.code += f' subq ${stack_needed}, %rsp\n'
# push arguments to the stack
offset = 0
for argument in node.children_arguments:
compile_ast(argument, compile_data)
compile_data.code += f' movq %rax, {offset}(%rsp)\n'
offset += 8
# call function and restore stack
compile_data.code += f' call {node.value}\n'
if stack_needed != 0:
compile_data.code += f' addq ${stack_needed}, %rsp\n'
case 'return':
compile_ast(node.child_expression, compile_data)
compile_data.code += f' leave\n'
compile_data.code += f' ret\n'
case 'int_literal':
compile_data.code += f' movq ${node.value}, %rax\n'
case 'string_literal':
label = compile_data.add_string_literal(node.value)
compile_data.code += f' movq ${label}, %rax\n'
case 'date_literal':
compile_data.code += f' movq ${timegm(node.value.timetuple())}, %rax\n'
case 'attribute_read':
compile_ast(node.child_identifier, compile_data)
compile_data.code += f' movq %rax, %rdi\n'
compile_data.code += f' movq $.{node.child_attribute.value}_format, %rsi\n'
compile_data.code += f' call __builtin_get_day_attr\n'
case 'do_until':
label_loop = compile_data.get_label()
compile_data.insert_label(label_loop)
# compile statements
for statement in node.children_statements:
compile_ast(statement, compile_data)
# compile condition
compile_ast(node.child_condition, compile_data)
compile_data.code += f' testq %rax, %rax\n'
compile_data.code += f' jz {label_loop}\n'
case 'do_unless' | 'unless_expression':
label_true = compile_data.get_label()
label_done = compile_data.get_label()
# compile condition
compile_ast(node.child_condition, compile_data)
compile_data.code += f' testq %rax, %rax\n'
compile_data.code += f' jnz {label_true}\n'
# compile false statements
if node.nodetype == 'unless_expression':
compile_ast(node.child_expression_false, compile_data)
elif node.nodetype == 'do_unless':
for statement in node.children_statements_false:
compile_ast(statement, compile_data)
else: assert False
compile_data.code += f' jmp {label_done}\n'
# compile true statements
compile_data.insert_label(label_true)
if node.nodetype == 'unless_expression':
compile_ast(node.child_expression_true, compile_data)
elif node.nodetype == 'do_unless':
for statement in node.children_statements_true:
compile_ast(statement, compile_data)
else: assert False
# add label for done
compile_data.insert_label(label_done)
case 'print':
for i, item in enumerate(node.children_items):
assert item.type in ['int', 'string', 'date']
compile_ast(item, compile_data)
compile_print_literal(item.type, compile_data)
# Print space if there are more items
if i < len(node.children_items) - 1:
compile_data.code += f' movl $\' \', %edi\n'
compile_data.code += f' call putchar\n'
# Print newline
compile_data.code += f' movl $\'\\n\', %edi\n'
compile_data.code += f' call putchar\n'
case _:
print_todo(f'Compile type \'{node.nodetype}\'', node)
# Command-line entry point: check the input file, generate assembly and
# optionally write it to a file and/or assemble and link it with gcc.
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('-d', '--debug', action='store_true', help='debug?')
    group = parser.add_mutually_exclusive_group(required=True)
    group.add_argument('--who', action='store_true', help='print out student IDs and NAMEs of authors')
    group.add_argument('-f', '--file', help='filename to process')
    parser.add_argument('-o', '--compile', help='output filename for compiled code. if not given, no compilation is done')
    parser.add_argument('-a', '--assembly', help='output filename for generated assembly code')
    args = parser.parse_args()
    if args.who:
        print('Author')
        print(' Student ID: 150189237')
        print(' Name: Oskari Alaranta')
    else:
        ast = syntax_check_file(args.file, args.debug)
        sem_data = SemData()
        semantic_check(ast, sem_data)
        if args.debug:
            tree_print.treeprint(ast, 'unicode')
        compile_data = CompileData(sem_data)
        compile_ast(ast, compile_data)
        assembly = compile_data.get_full_code()
        if args.assembly is None:
            # Without any output target the assembly goes to stdout
            if args.compile is None:
                print(assembly)
        else:
            with open(args.assembly, 'w', encoding='utf-8') as file:
                file.write(assembly)
        if args.compile is not None:
            # gcc reads the assembly from stdin ('-') and links statically
            result = subprocess.run(
                ['gcc', '-x', 'assembler', '-o', args.compile, '-static', '-'],
                input=assembly,
                encoding='utf-8',
            )
            # Propagate an assembler/linker failure instead of exiting 0
            if result.returncode != 0:
                raise SystemExit(result.returncode)