#!/bin/env python3 import argparse from calendar import timegm from copy import deepcopy from datetime import date, timedelta import tree_print from build_ast import ASTnode, syntax_check_file class SemData: def __init__(self): self.scope = None self.root = None self.callables = {} self.global_symbol_table = {} self.local_symbol_table = {} def semantic_error(msg: str, node: ASTnode) -> None: print(f'\033[31mSemantic Error: {msg} at line {node.lineno}\033[m') raise SystemExit(1) def print_todo(msg: str, node: ASTnode) -> None: print(f'\033[33mTODO: {msg} at line {node.lineno}\033[m') raise SystemExit(2) def semantic_check(node: ASTnode, sem_data: SemData) -> None | ASTnode: if sem_data.root is None: sem_data.root = node match node.nodetype: case 'program': # Collect function and procedure definitions first, # since they can be called before they are defined for child in node.children_definitions: if child.nodetype in ['function_definition', 'procedure_definition']: if child.value in sem_data.callables: semantic_error(f'Redefinition of {child.nodetype.split("_")[0]} \'{child.value}\'', child) sem_data.callables[child.value] = child # Then do the actual semantic checking for child in node.children_definitions: semantic_check(child, sem_data) for child in node.children_statements: if semantic_check(child, sem_data) is not None: semantic_error(f'Expression return value is not handled', child) return None case 'variable_definition': # Check if variable is already defined symbol_table = sem_data.global_symbol_table if sem_data.scope is not None: symbol_table = sem_data.local_symbol_table if node.value in symbol_table: semantic_error(f'Redefinition of variable \'{node.value}\'', node) # Check if expression is valid and store it in symbol table variable = semantic_check(node.child_expression, sem_data) if variable is None or variable.type not in ['int', 'string', 'date']: semantic_error(f'Invalid variable type \'{variable.type if variable is not None else None}\'', node) symbol_table[node.value] = variable return None case 'function_definition' | 'procedure_definition': # Function and procedures are added to global symbol table # as the first step, so they can be called before they are defined assert node.value in sem_data.callables # Local symbols table should be empty while doing checking, # since functions and procedures can only be defined in global scope assert len(sem_data.local_symbol_table) == 0 and sem_data.scope is None sem_data.scope = node # Collect local arguments for formal in node.children_formals: if formal.value in sem_data.local_symbol_table: semantic_error(f'Redefinition of variable \'{formal.value}\' in {node.nodetype.split("_")[0]} \'{node.value}\' arguments', node) sem_data.local_symbol_table[formal.value] = formal # Collect local variables for variable_definition in node.children_variable_definitions: semantic_check(variable_definition, sem_data) # Check return type if node.nodetype == 'function_definition': expression = semantic_check(node.child_expression, sem_data) if expression is None: semantic_error(f'Function \'{node.value}\' must return a value', node) if node.child_return_type == 'auto': node.child_return_type = expression.type if expression.type != node.child_return_type: semantic_error(f'Function \'{node.value}\' return type is {node.child_return_type} but returns {expression.type}', node) elif node.nodetype == 'procedure_definition': returns = None for statement in node.children_statements: returns = None value = semantic_check(statement, sem_data) if value is None: continue if value.nodetype != 'return': semantic_error(f'Expression return value is not handled', statement) if node.child_return_type is None: semantic_error(f'Procedure \'{node.value}\' does not have a return type', node) if node.child_return_type == 'auto': node.child_return_type = value.type if value.type != node.child_return_type: semantic_error(f'Procedure \'{node.value}\' return type is {node.child_return_type} but returns {value.type}', node) returns = value.type if returns is None and node.child_return_type is not None: if node.child_return_type != 'void': semantic_error(f'Procedure \'{node.value}\' must return a value when scope exits', node) else: assert False node.type = node.child_return_type sem_data.scope = None sem_data.local_symbol_table = {} return None case 'return': if sem_data.scope is None or sem_data.scope.nodetype != 'procedure_definition': semantic_error(f'Keyword \'return\' can only appear in procefure_definition') result = semantic_check(node.child_expression, sem_data) if result is None: semantic_error(f'Procedure \'{sem_data.scope.value}\' must return a value', node) node.type = result.type return node case 'date_literal' | 'int_literal' | 'string_literal': node.type = node.nodetype.split('_')[0] return node case 'assignment': lhs = semantic_check(node.child_lhs, sem_data) rhs = semantic_check(node.child_rhs, sem_data) if lhs is None or rhs is None or lhs.type != rhs.type: semantic_error(f'Invalid assignment of \'{rhs.type if rhs is not None else None}\' to \'{lhs.type if lhs is not None else None}\'', node) return None case 'binary_op': lhs = semantic_check(node.child_lhs, sem_data) rhs = semantic_check(node.child_rhs, sem_data) if lhs is None or rhs is None: semantic_error(f'Invalid operands \'{lhs.type if lhs is not None else None}\' and \'{rhs.type if rhs is not None else None}\' for binary operation {node.value}', node) # Validate operands and result type if node.value in ['*', '/']: if lhs.type == 'int' and rhs.type == 'int': node.type = 'int' return node elif node.value == '+': if lhs.type == 'date' and rhs.type == 'int': node.type = 'date' return node if lhs.type == 'int' and rhs.type == 'int': node.type = 'int' return node elif node.value == '-': if lhs.type == 'date' and rhs.type == 'int': node.type = 'date' return node if lhs.type == 'date' and rhs.type == 'date': node.type = 'int' return node if lhs.type == 'int' and rhs.type == 'int': node.type = 'int' return node elif node.value in ['<', '=']: if lhs.type == rhs.type: node.type = 'bool' return node semantic_error(f'Invalid operands \'{lhs.type}\' and \'{rhs.type}\' for operation {node.value}', node) case 'identifier': # Check if variable is defined symbol = None if node.value in sem_data.local_symbol_table: symbol = sem_data.local_symbol_table[node.value] if node.value in sem_data.global_symbol_table: symbol = sem_data.global_symbol_table[node.value] if symbol is not None: node.type = symbol.type return symbol semantic_error(f'Variable \'{node.value}\' not defined', node) case 'function_call' | 'procedure_call': # Handle built in functions if node.nodetype == 'function_call' and node.value == 'Today': if len(node.children_arguments) != 0: semantic_error(f'Builtin function \'Today\' takes no arguments', node) node.type = 'date' return node # Check if function/procedure is defined if node.value not in sem_data.callables: semantic_error(f'{node.nodetype.split("_")[0]} \'{node.value}\' not defined', node) func = sem_data.callables[node.value] # Check if arguments match (count and types) if len(node.children_arguments) != len(func.children_formals): semantic_error(f'Argument count mismatch for {node.nodetype.split("_")[0]} \'{node.value}\', expected {len(func.children_formals)} but got {len(node.children_arguments)}', node) for formal, actual in zip(func.children_formals, node.children_arguments): resolved = semantic_check(actual, sem_data) if resolved is None or formal.type != resolved.type: semantic_error(f'Argument type mismatch for {node.nodetype.split("_")[0]} \'{node.value}\', expected \'{formal.type}\' but got \'{resolved.type if resolved is not None else None}\'', node) # Set return type and return node if func has a return type node.type = func.child_return_type return node if node.type is not None else None case 'do_unless': # Validate condition condition = semantic_check(node.child_condition, sem_data) if condition is None or condition.type != 'bool': semantic_error('Condition must be of type \'bool\'', node) # Validate both branches for statement in node.children_statements_true: if semantic_check(statement, sem_data) is not None: semantic_error(f'Expression return value is not handled', statement) for statement in node.children_statements_false: if semantic_check(statement, sem_data) is not None: semantic_error(f'Expression return value is not handled', statement) return None case 'do_until': # Validate condition condition = semantic_check(node.child_condition, sem_data) if condition is None or condition.type != 'bool': semantic_error('Condition must be of type bool', node) # Validate body for statement in node.children_statements: if semantic_check(statement, sem_data) is not None: semantic_error(f'Expression return value is not handled', statement) return None case 'unless_expression': # Validate condition condition = semantic_check(node.child_condition, sem_data) if condition is None or condition.type != 'bool': semantic_error('Condition must be of type bool', node) # Validate both branches expression_true = semantic_check(node.child_expression_true, sem_data) expression_false = semantic_check(node.child_expression_false, sem_data) if expression_true is None or expression_false is None or expression_true.type != expression_false.type: semantic_error(f'Branches must return the same type, got \'{expression_false.type}\' and \'{expression_true.type}\'', node) node.type = expression_true.type return node case 'attribute_read' | 'attribute_write': # Check if variable is defined symbol = None if node.child_identifier.value in sem_data.local_symbol_table: symbol = sem_data.local_symbol_table[node.child_identifier.value] elif node.child_identifier.value in sem_data.global_symbol_table: symbol = sem_data.global_symbol_table[node.child_identifier.value] else: semantic_error(f'Variable \'{node.child_identifier.value}\' not defined', node.child_identifier) # Validate attribute assert node.child_attribute.nodetype == 'identifier' if symbol.type != 'date': semantic_error(f'Cannot access attribute of non-date variable', node.child_attribute) valid_attributes = ['day', 'month', 'year'] if node.nodetype == 'attribute_read': valid_attributes += ['weekday', 'weeknum'] if node.child_attribute.value not in valid_attributes: semantic_error(f'Invalid attribute \'{node.child_attribute.value}\' for {node.nodetype.split("_")[0]}, allowed values {valid_attributes}', node.child_attribute) node.type = 'int' return node case 'print': for item in node.children_items: value = semantic_check(item, sem_data) if value is None or value.type not in ['int', 'string', 'date']: semantic_error('Print argument can only be \'int\', \'date\' or \'string\'', node) return None case _: print_todo(f'Semantic check type \'{node.nodetype}\'', node) class CompileData: def __init__(self, sem_data: SemData): self.sem_data = sem_data self.string_literals = [] self.label_counter = 0 self.callables = {} self.scope = None self.code = '' self.callables['__builtin_today'] = '''\ pushq %rbp movq %rsp, %rbp xorq %rdi, %rdi call time movq %rax, %rdi movq $86400, %rcx xorq %rdx, %rdx divq %rcx movq %rdi, %rax subq %rdx, %rax popq %rbp ret ''' self.callables['__builtin_print_date'] = '''\ pushq %rbp movq %rsp, %rbp subq $16, %rsp movq %rdi, 0(%rsp) leaq 0(%rsp), %rdi call gmtime movq $date_buffer, %rdi movq $(date_buffer_end - date_buffer), %rsi movq $date_format, %rdx movq %rax, %rcx call strftime movq $str_format, %rdi movq $date_buffer, %rsi call printf leave ret ''' def get_label(self) -> str: self.label_counter += 1 return f'.L{self.label_counter - 1}' def insert_label(self, label) -> None: self.code += f'{label}:\n' def add_string_literal(self, value: str) -> str: for index, string in enumerate(self.string_literals): if string == value: return f'S{index}' self.string_literals.append(value) return f'S{len(self.string_literals) - 1}' def symbol_address(self, symbol: str) -> str: if self.scope is not None: for index, formal in enumerate(self.scope.children_formals): if formal.value == symbol: offset = 8 * index + 16 return f'{offset}(%rbp)' for index, variable in enumerate(self.scope.children_variable_definitions): if variable.value == symbol: offset = 8 * index + 8 return f'-{offset}(%rbp)' if symbol in self.sem_data.global_symbol_table: offset = 8 * list(self.sem_data.global_symbol_table.keys()).index(symbol) return f'(globals + {offset})' assert False def get_full_code(self) -> str: # Data section with string literals prefix = '.section .data\n' prefix += 'int_format: .asciz "%lld"\n' prefix += 'str_format: .asciz "%s"\n' prefix += 'date_format: .asciz "%Y-%m-%d"\n' for index, string in enumerate(self.string_literals): prefix += f'S{index}: .asciz "{string}"\n' prefix += '\n' # BSS section for uninitialized data prefix += '.section .bss\n' prefix += 'date_buffer:\n' prefix += ' .skip 128\n' prefix += 'date_buffer_end:\n' if len(self.sem_data.global_symbol_table) != 0: prefix += 'globals:\n' prefix += f' .skip {len(sem_data.global_symbol_table) * 8}\n' prefix += '\n' # Text section with code prefix += '.section .text\n' prefix += '\n' # Add function and procedure definitions for name, code in self.callables.items(): prefix += name + ':\n' prefix += code prefix += '\n' prefix += '.global main\n' prefix += 'main:\n' prefix += ' pushq %rbp\n' prefix += ' movq %rsp, %rbp\n' postfix = ' xorq %rax, %rax\n' postfix += ' leave\n' postfix += ' ret\n' return prefix + self.code + postfix def compile_print_literal(print_type: str, compile_data: CompileData) -> None: if print_type == 'int': compile_data.code += f' movq $int_format, %rdi\n' compile_data.code += f' movq %rax, %rsi\n' compile_data.code += f' call printf\n' elif print_type == 'string': compile_data.code += f' movq $str_format, %rdi\n' compile_data.code += f' movq %rax, %rsi\n' compile_data.code += f' call printf\n' elif print_type == 'date': compile_data.code += f' movq %rax, %rdi\n' compile_data.code += f' call __builtin_print_date\n' else: assert False def compile_ast(node: ASTnode, compile_data: CompileData) -> None: match node.nodetype: case 'program': # Compile function and procedure definitions for definition in node.children_definitions: if definition.nodetype not in ['function_definition', 'procedure_definition']: continue assert len(compile_data.code) == 0 assert compile_data.scope is None compile_data.scope = definition # initialize stack frame compile_data.code += ' pushq %rbp\n' compile_data.code += ' movq %rsp, %rbp\n' # initialize local variables stack_size = 8 * len(definition.children_variable_definitions) if stack_size % 16 != 0: stack_size += 8 if stack_size != 0: compile_data.code += f' subq ${stack_size}, %rsp\n' for variable in definition.children_variable_definitions: address = compile_data.symbol_address(variable.value) compile_ast(variable.child_expression, compile_data) compile_data.code += f' movq %rax, {address}\n' # compile statements if definition.nodetype == 'function_definition': compile_ast(definition.child_expression, compile_data) elif definition.nodetype == 'procedure_definition': for statement in definition.children_statements: compile_ast(statement, compile_data) else: assert False # return from procedure compile_data.code += f' leave\n' compile_data.code += f' ret\n' compile_data.callables[definition.value] = compile_data.code compile_data.code = '' compile_data.scope = None # Initialize global variables for index, (name, variable) in enumerate(compile_data.sem_data.global_symbol_table.items()): offset = 8 * index compile_ast(variable, compile_data) compile_data.code += f' movq %rax, (globals + {offset})\n' # Compile program statements for statement in node.children_statements: compile_ast(statement, compile_data) case 'variable_definition' | 'function_definition' | 'procedure_definition': assert False case 'identifier': address = compile_data.symbol_address(node.value) compile_data.code += f' movq {address}, %rax\n' case 'assignment': address = compile_data.symbol_address(node.child_lhs.value) compile_ast(node.child_rhs, compile_data) compile_data.code += f' movq %rax, {address}\n' case 'binary_op': assert node.value in ['+', '-', '*', '/', '<', '='] if node.value in ['*', '/']: assert node.child_lhs.type == 'int' else: assert node.child_lhs.type in ['int', 'date'] if node.value == '-' and node.child_lhs.type == 'date': assert node.child_rhs.type in ['int', 'date'] else: assert node.child_rhs.type == 'int' # calculate LHS and store it on stack compile_data.code += f' subq $16, %rsp\n' compile_ast(node.child_lhs, compile_data) compile_data.code += f' movq %rax, 0(%rsp)\n' # calculate RHS and store it in RCX compile_ast(node.child_rhs, compile_data) if node.child_lhs.type == 'date' and node.child_rhs.type == 'int': # multiply RHS by number of seconds in a day so we can perform arithmetic on dates compile_data.code += f' imulq $86400, %rax, %rcx\n' else: compile_data.code += f' movq %rax, %rcx\n' # prepare registers, RAX contains LHS and RCX contains RHS # and restore restore stack compile_data.code += f' movq 0(%rsp), %rax\n' compile_data.code += f' addq $16, %rsp\n' # perform operation if node.value == '+': compile_data.code += f' addq %rcx, %rax\n' elif node.value == '-': compile_data.code += f' subq %rcx, %rax\n' elif node.value == '*': compile_data.code += f' imulq %rcx, %rax\n' elif node.value == '/': compile_data.code += f' cqo\n' compile_data.code += f' idivq %rcx\n' elif node.value == '<': compile_data.code += f' cmpq %rcx, %rax\n' compile_data.code += f' setl %al\n' compile_data.code += f' movzbq %al, %rax\n' elif node.value == '=': compile_data.code += f' cmpq %rcx, %rax\n' compile_data.code += f' sete %al\n' compile_data.code += f' movzbq %al, %rax\n' else: assert False # if both operands are dates, divide result by number of seconds in a day if node.child_lhs.type == 'date' and node.child_rhs.type == 'date': assert node.value == '-' compile_data.code += f' movq $86400, %rcx\n' compile_data.code += f' cqo\n' compile_data.code += f' idivq %rcx\n' case 'function_call' | 'procedure_call': if node.value == 'Today': compile_data.code += f' call __builtin_today\n' else: # align stack stack_needed = len(node.children_arguments) * 8 if stack_needed % 16 != 0: stack_needed += 8 if stack_needed != 0: compile_data.code += f' subq ${stack_needed}, %rsp\n' # push arguments to the stack offset = 0 for argument in node.children_arguments: compile_ast(argument, compile_data) compile_data.code += f' movq %rax, {offset}(%rsp)\n' offset += 8 # call function and restore stack compile_data.code += f' call {node.value}\n' if stack_needed != 0: compile_data.code += f' addq ${stack_needed}, %rsp\n' case 'return': compile_ast(node.child_expression, compile_data) compile_data.code += f' leave\n' compile_data.code += f' ret\n' case 'int_literal': compile_data.code += f' movq ${node.value}, %rax\n' case 'string_literal': label = compile_data.add_string_literal(node.value) compile_data.code += f' movq ${label}, %rax\n' case 'date_literal': compile_data.code += f' movq ${timegm(node.value.timetuple())}, %rax\n' case 'do_until': label_loop = compile_data.get_label() compile_data.insert_label(label_loop) # compile statements for statement in node.children_statements: compile_ast(statement, compile_data) # compile condition compile_ast(node.child_condition, compile_data) compile_data.code += f' testq %rax, %rax\n' compile_data.code += f' jz {label_loop}\n' case 'do_unless' | 'unless_expression': label_true = compile_data.get_label() label_done = compile_data.get_label() # compile condition compile_ast(node.child_condition, compile_data) compile_data.code += f' testq %rax, %rax\n' compile_data.code += f' jnz {label_true}\n' # compile false statements if node.nodetype == 'unless_expression': compile_ast(node.child_expression_false, compile_data) elif node.nodetype == 'do_unless': for statement in node.children_statements_false: compile_ast(statement, compile_data) else: assert False compile_data.code += f' jmp {label_done}\n' # compile true statements compile_data.insert_label(label_true) if node.nodetype == 'unless_expression': compile_ast(node.child_expression_true, compile_data) elif node.nodetype == 'do_unless': for statement in node.children_statements_true: compile_ast(statement, compile_data) else: assert False # add label for done compile_data.insert_label(label_done) case 'print': for i, item in enumerate(node.children_items): assert item.type in ['int', 'string', 'date'] compile_ast(item, compile_data) compile_print_literal(item.type, compile_data) # Print space if there are more items if i < len(node.children_items) - 1: compile_data.code += f' movl $\' \', %edi\n' compile_data.code += f' call putchar\n' # Print newline compile_data.code += f' movl $\'\\n\', %edi\n' compile_data.code += f' call putchar\n' case _: print_todo(f'Compile type \'{node.nodetype}\'', node) if __name__ == '__main__': parser = argparse.ArgumentParser() parser.add_argument('-d', '--debug', action='store_true', help='debug?') group = parser.add_mutually_exclusive_group(required=True) group.add_argument('--who', action='store_true', help='print out student IDs and NAMEs of authors') group.add_argument('-f', '--file', help='filename to process') args = parser.parse_args() if args.who: print('Author') print(' Student ID: 150189237') print(' Name: Oskari Alaranta') else: ast = syntax_check_file(args.file, args.debug) sem_data = SemData() semantic_check(ast, sem_data) #tree_print.treeprint(ast, 'unicode') compile_data = CompileData(sem_data) compile_ast(ast, compile_data) print(compile_data.get_full_code())