Implement simple assembly transpiler for phase 4
The generated code is not optimized at all. This covers every aspect of phase 4 except date attribute access.
This commit is contained in:
parent
53e945fbe5
commit
9238eae5ee
|
@ -1,6 +1,7 @@
|
|||
#!/bin/env python3
|
||||
|
||||
import argparse
|
||||
from calendar import timegm
|
||||
from copy import deepcopy
|
||||
from datetime import date, timedelta
|
||||
import tree_print
|
||||
|
@ -10,6 +11,7 @@ class SemData:
|
|||
def __init__(self):
|
||||
self.scope = None
|
||||
self.root = None
|
||||
self.callables = {}
|
||||
self.global_symbol_table = {}
|
||||
self.local_symbol_table = {}
|
||||
|
||||
|
@ -30,9 +32,9 @@ def semantic_check(node: ASTnode, sem_data: SemData) -> None | ASTnode:
|
|||
# since they can be called before they are defined
|
||||
for child in node.children_definitions:
|
||||
if child.nodetype in ['function_definition', 'procedure_definition']:
|
||||
if child.value in sem_data.global_symbol_table:
|
||||
if child.value in sem_data.callables:
|
||||
semantic_error(f'Redefinition of {child.nodetype.split("_")[0]} \'{child.value}\'', child)
|
||||
sem_data.global_symbol_table[child.value] = child
|
||||
sem_data.callables[child.value] = child
|
||||
|
||||
# Then do the actual semantic checking
|
||||
for child in node.children_definitions:
|
||||
|
@ -60,7 +62,7 @@ def semantic_check(node: ASTnode, sem_data: SemData) -> None | ASTnode:
|
|||
case 'function_definition' | 'procedure_definition':
|
||||
# Function and procedures are added to global symbol table
|
||||
# as the first step, so they can be called before they are defined
|
||||
assert node.value in sem_data.global_symbol_table
|
||||
assert node.value in sem_data.callables
|
||||
|
||||
# Local symbols table should be empty while doing checking,
|
||||
# since functions and procedures can only be defined in global scope
|
||||
|
@ -159,7 +161,7 @@ def semantic_check(node: ASTnode, sem_data: SemData) -> None | ASTnode:
|
|||
if lhs.type == 'int' and rhs.type == 'int':
|
||||
node.type = 'int'
|
||||
return node
|
||||
elif node.value in ['<', '>', '=']:
|
||||
elif node.value in ['<', '=']:
|
||||
if lhs.type == rhs.type:
|
||||
node.type = 'bool'
|
||||
return node
|
||||
|
@ -186,9 +188,9 @@ def semantic_check(node: ASTnode, sem_data: SemData) -> None | ASTnode:
|
|||
return node
|
||||
|
||||
# Check if function/procedure is defined
|
||||
if node.value not in sem_data.global_symbol_table:
|
||||
if node.value not in sem_data.callables:
|
||||
semantic_error(f'{node.nodetype.split("_")[0]} \'{node.value}\' not defined', node)
|
||||
func = sem_data.global_symbol_table[node.value]
|
||||
func = sem_data.callables[node.value]
|
||||
|
||||
# Check if arguments match (count and types)
|
||||
if len(node.children_arguments) != len(func.children_formals):
|
||||
|
@ -273,10 +275,334 @@ def semantic_check(node: ASTnode, sem_data: SemData) -> None | ASTnode:
|
|||
case _:
|
||||
print_todo(f'Semantic check type \'{node.nodetype}\'', node)
|
||||
|
||||
def execute_ast(node: ASTnode, sem_data: SemData) -> None | ASTnode:
|
||||
class CompileData:
    """Mutable state shared by the assembly code generator.

    Attributes:
        sem_data: result of the semantic pass; only ``global_symbol_table``
            is read here.
        string_literals: distinct string constants, emitted as ``S<index>``.
        label_counter: next free number for ``.L<n>`` local labels.
        callables: maps an assembly label to the body emitted under it.
        scope: definition node currently being compiled, or ``None`` while
            compiling top-level code.
        code: assembly text accumulated for the unit being compiled.
    """

    def __init__(self, sem_data: 'SemData'):
        self.sem_data = sem_data
        self.string_literals = []
        self.label_counter = 0
        self.callables = {}
        self.scope = None
        self.code = ''

        # Returns in %rax the value of time(NULL) rounded down to a multiple
        # of 86400 (rax - rax % 86400).
        self.callables['__builtin_today'] = (
            'pushq %rbp\n'
            'movq %rsp, %rbp\n'
            'xorq %rdi, %rdi\n'
            'call time\n'
            'movq %rax, %rdi\n'
            'movq $86400, %rcx\n'
            'xorq %rdx, %rdx\n'
            'divq %rcx\n'
            'movq %rdi, %rax\n'
            'subq %rdx, %rax\n'
            'popq %rbp\n'
            'ret\n'
        )

        # Takes a timestamp in %rdi, spills it to the stack so gmtime can be
        # given a pointer to it, formats it into date_buffer with strftime
        # and prints the buffer with printf.
        self.callables['__builtin_print_date'] = (
            'pushq %rbp\n'
            'movq %rsp, %rbp\n'
            'subq $16, %rsp\n'
            'movq %rdi, 0(%rsp)\n'
            'leaq 0(%rsp), %rdi\n'
            'call gmtime\n'
            'movq $date_buffer, %rdi\n'
            'movq $(date_buffer_end - date_buffer), %rsi\n'
            'movq $date_format, %rdx\n'
            'movq %rax, %rcx\n'
            'call strftime\n'
            'movq $str_format, %rdi\n'
            'movq $date_buffer, %rsi\n'
            # printf is variadic: %al must hold the number of vector
            # registers used (none), not strftime's leftover return value.
            'xorl %eax, %eax\n'
            'call printf\n'
            'leave\n'
            'ret\n'
        )

    def get_label(self) -> str:
        """Return a fresh local label (``.L0``, ``.L1``, ...)."""
        label = f'.L{self.label_counter}'
        self.label_counter += 1
        return label

    def insert_label(self, label) -> None:
        """Append ``label`` as a label definition to the current code."""
        self.code += f'{label}:\n'

    def add_string_literal(self, value: str) -> str:
        """Return the data label for ``value``, registering it on first use.

        Equal strings share one label, so each distinct literal is emitted
        only once.
        """
        try:
            index = self.string_literals.index(value)
        except ValueError:
            index = len(self.string_literals)
            self.string_literals.append(value)
        return f'S{index}'

    def symbol_address(self, symbol: str) -> str:
        """Return the assembly operand that addresses ``symbol``.

        Lookup order: formals of the current scope (``16(%rbp)`` upwards),
        locals of the current scope (``-8(%rbp)`` downwards), then globals
        (8-byte slots in ``globals``, in symbol-table insertion order).

        Raises:
            AssertionError: if the symbol is not known anywhere.
        """
        if self.scope is not None:
            for index, formal in enumerate(self.scope.children_formals):
                if formal.value == symbol:
                    # skip the saved %rbp and the return address
                    return f'{8 * index + 16}(%rbp)'
            for index, variable in enumerate(self.scope.children_variable_definitions):
                if variable.value == symbol:
                    return f'-{8 * index + 8}(%rbp)'
        if symbol in self.sem_data.global_symbol_table:
            offset = 8 * list(self.sem_data.global_symbol_table.keys()).index(symbol)
            return f'(globals + {offset})'
        # Raised explicitly so the check survives ``python -O``.
        raise AssertionError(f'unknown symbol {symbol!r}')

    def get_full_code(self) -> str:
        """Assemble the complete program text.

        Layout: ``.data`` (format strings and string literals), ``.bss``
        (date buffer and global variable slots), ``.text`` (every entry of
        ``callables``, then ``main`` wrapping the accumulated ``code``).
        """
        parts = [
            # Data section with string literals
            '.section .data\n',
            'int_format: .asciz "%lld"\n',
            'str_format: .asciz "%s"\n',
            'date_format: .asciz "%Y-%m-%d"\n',
        ]
        parts.extend(
            f'S{index}: .asciz "{string}"\n'
            for index, string in enumerate(self.string_literals)
        )
        parts.append('\n')

        # BSS section for uninitialized data
        parts.append('.section .bss\n')
        parts.append('date_buffer:\n')
        parts.append(' .skip 128\n')
        parts.append('date_buffer_end:\n')
        global_count = len(self.sem_data.global_symbol_table)
        if global_count != 0:
            parts.append('globals:\n')
            # Use the instance's sem_data rather than a module-level global
            # that only exists when the file is run as a script.
            parts.append(f' .skip {global_count * 8}\n')
        parts.append('\n')

        # Text section with code
        parts.append('.section .text\n')
        parts.append('\n')

        # Add function and procedure definitions
        for name, code in self.callables.items():
            parts.append(name + ':\n')
            parts.append(code)
            parts.append('\n')

        parts.append('.global main\n')
        parts.append('main:\n')
        parts.append(' pushq %rbp\n')
        parts.append(' movq %rsp, %rbp\n')

        parts.append(self.code)

        # main returns 0
        parts.append(' xorq %rax, %rax\n')
        parts.append(' leave\n')
        parts.append(' ret\n')

        return ''.join(parts)
|
||||
|
||||
def compile_print_literal(print_type: str, compile_data: 'CompileData') -> None:
    """Emit code that prints the value currently held in %rax.

    Args:
        print_type: 'int', 'string' or 'date'; selects how %rax is printed.
        compile_data: code generation state; the instructions are appended
            to its ``code`` attribute.

    Raises:
        AssertionError: if ``print_type`` is not one of the supported types.
    """
    printf_formats = {'int': 'int_format', 'string': 'str_format'}

    if print_type in printf_formats:
        compile_data.code += f' movq ${printf_formats[print_type]}, %rdi\n'
        compile_data.code += ' movq %rax, %rsi\n'
        # printf is variadic: %al must hold the number of vector registers
        # used (none), but %rax still contains the value being printed.
        compile_data.code += ' xorl %eax, %eax\n'
        compile_data.code += ' call printf\n'
    elif print_type == 'date':
        compile_data.code += ' movq %rax, %rdi\n'
        compile_data.code += ' call __builtin_print_date\n'
    else:
        # Raised explicitly so the check survives ``python -O``.
        raise AssertionError(f'cannot print value of type {print_type!r}')
|
||||
|
||||
def compile_ast(node: ASTnode, compile_data: CompileData) -> None:
|
||||
match node.nodetype:
|
||||
case 'program':
|
||||
# Compile function and procedure definitions
|
||||
for definition in node.children_definitions:
|
||||
if definition.nodetype not in ['function_definition', 'procedure_definition']:
|
||||
continue
|
||||
assert len(compile_data.code) == 0
|
||||
assert compile_data.scope is None
|
||||
compile_data.scope = definition
|
||||
# initialize stack frame
|
||||
compile_data.code += ' pushq %rbp\n'
|
||||
compile_data.code += ' movq %rsp, %rbp\n'
|
||||
# initialize local variables
|
||||
stack_size = 8 * len(definition.children_variable_definitions)
|
||||
if stack_size % 16 != 0:
|
||||
stack_size += 8
|
||||
if stack_size != 0:
|
||||
compile_data.code += f' subq ${stack_size}, %rsp\n'
|
||||
for variable in definition.children_variable_definitions:
|
||||
address = compile_data.symbol_address(variable.value)
|
||||
compile_ast(variable.child_expression, compile_data)
|
||||
compile_data.code += f' movq %rax, {address}\n'
|
||||
# compile statements
|
||||
if definition.nodetype == 'function_definition':
|
||||
compile_ast(definition.child_expression, compile_data)
|
||||
elif definition.nodetype == 'procedure_definition':
|
||||
for statement in definition.children_statements:
|
||||
compile_ast(statement, compile_data)
|
||||
else: assert False
|
||||
# return from procedure
|
||||
compile_data.code += f' leave\n'
|
||||
compile_data.code += f' ret\n'
|
||||
compile_data.callables[definition.value] = compile_data.code
|
||||
compile_data.code = ''
|
||||
compile_data.scope = None
|
||||
# Initialize global variables
|
||||
for index, (name, variable) in enumerate(compile_data.sem_data.global_symbol_table.items()):
|
||||
offset = 8 * index
|
||||
compile_ast(variable, compile_data)
|
||||
compile_data.code += f' movq %rax, (globals + {offset})\n'
|
||||
# Compile program statements
|
||||
for statement in node.children_statements:
|
||||
compile_ast(statement, compile_data)
|
||||
case 'variable_definition' | 'function_definition' | 'procedure_definition':
|
||||
assert False
|
||||
case 'identifier':
|
||||
address = compile_data.symbol_address(node.value)
|
||||
compile_data.code += f' movq {address}, %rax\n'
|
||||
case 'assignment':
|
||||
address = compile_data.symbol_address(node.child_lhs.value)
|
||||
compile_ast(node.child_rhs, compile_data)
|
||||
compile_data.code += f' movq %rax, {address}\n'
|
||||
case 'binary_op':
|
||||
assert node.value in ['+', '-', '*', '/', '<', '=']
|
||||
|
||||
if node.value in ['*', '/']:
|
||||
assert node.child_lhs.type == 'int'
|
||||
else:
|
||||
assert node.child_lhs.type in ['int', 'date']
|
||||
|
||||
if node.value == '-' and node.child_lhs.type == 'date':
|
||||
assert node.child_rhs.type in ['int', 'date']
|
||||
else:
|
||||
assert node.child_rhs.type == 'int'
|
||||
|
||||
# calculate LHS and store it on stack
|
||||
compile_data.code += f' subq $16, %rsp\n'
|
||||
compile_ast(node.child_lhs, compile_data)
|
||||
compile_data.code += f' movq %rax, 0(%rsp)\n'
|
||||
|
||||
# calculate RHS and store it in RCX
|
||||
compile_ast(node.child_rhs, compile_data)
|
||||
if node.child_lhs.type == 'date' and node.child_rhs.type == 'int':
|
||||
# multiply RHS by number of seconds in a day so we can perform arithmetic on dates
|
||||
compile_data.code += f' imulq $86400, %rax, %rcx\n'
|
||||
else:
|
||||
compile_data.code += f' movq %rax, %rcx\n'
|
||||
|
||||
# prepare registers, RAX contains LHS and RCX contains RHS
|
||||
# and restore restore stack
|
||||
compile_data.code += f' movq 0(%rsp), %rax\n'
|
||||
compile_data.code += f' addq $16, %rsp\n'
|
||||
|
||||
# perform operation
|
||||
if node.value == '+':
|
||||
compile_data.code += f' addq %rcx, %rax\n'
|
||||
elif node.value == '-':
|
||||
compile_data.code += f' subq %rcx, %rax\n'
|
||||
elif node.value == '*':
|
||||
compile_data.code += f' imulq %rcx, %rax\n'
|
||||
elif node.value == '/':
|
||||
compile_data.code += f' cqo\n'
|
||||
compile_data.code += f' idivq %rcx\n'
|
||||
elif node.value == '<':
|
||||
compile_data.code += f' cmpq %rcx, %rax\n'
|
||||
compile_data.code += f' setl %al\n'
|
||||
compile_data.code += f' movzbq %al, %rax\n'
|
||||
elif node.value == '=':
|
||||
compile_data.code += f' cmpq %rcx, %rax\n'
|
||||
compile_data.code += f' sete %al\n'
|
||||
compile_data.code += f' movzbq %al, %rax\n'
|
||||
else: assert False
|
||||
|
||||
# if both operands are dates, divide result by number of seconds in a day
|
||||
if node.child_lhs.type == 'date' and node.child_rhs.type == 'date':
|
||||
assert node.value == '-'
|
||||
compile_data.code += f' movq $86400, %rcx\n'
|
||||
compile_data.code += f' cqo\n'
|
||||
compile_data.code += f' idivq %rcx\n'
|
||||
case 'function_call' | 'procedure_call':
|
||||
if node.value == 'Today':
|
||||
compile_data.code += f' call __builtin_today\n'
|
||||
else:
|
||||
# align stack
|
||||
stack_needed = len(node.children_arguments) * 8
|
||||
if stack_needed % 16 != 0:
|
||||
stack_needed += 8
|
||||
if stack_needed != 0:
|
||||
compile_data.code += f' subq ${stack_needed}, %rsp\n'
|
||||
|
||||
# push arguments to the stack
|
||||
offset = 0
|
||||
for argument in node.children_arguments:
|
||||
compile_ast(argument, compile_data)
|
||||
compile_data.code += f' movq %rax, {offset}(%rsp)\n'
|
||||
offset += 8
|
||||
|
||||
# call function and restore stack
|
||||
compile_data.code += f' call {node.value}\n'
|
||||
if stack_needed != 0:
|
||||
compile_data.code += f' addq ${stack_needed}, %rsp\n'
|
||||
case 'return':
|
||||
compile_ast(node.child_expression, compile_data)
|
||||
compile_data.code += f' leave\n'
|
||||
compile_data.code += f' ret\n'
|
||||
case 'int_literal':
|
||||
compile_data.code += f' movq ${node.value}, %rax\n'
|
||||
case 'string_literal':
|
||||
label = compile_data.add_string_literal(node.value)
|
||||
compile_data.code += f' movq ${label}, %rax\n'
|
||||
case 'date_literal':
|
||||
compile_data.code += f' movq ${timegm(node.value.timetuple())}, %rax\n'
|
||||
case 'do_until':
|
||||
label_loop = compile_data.get_label()
|
||||
compile_data.insert_label(label_loop)
|
||||
|
||||
# compile statements
|
||||
for statement in node.children_statements:
|
||||
compile_ast(statement, compile_data)
|
||||
|
||||
# compile condition
|
||||
compile_ast(node.child_condition, compile_data)
|
||||
compile_data.code += f' testq %rax, %rax\n'
|
||||
compile_data.code += f' jz {label_loop}\n'
|
||||
case 'do_unless' | 'unless_expression':
|
||||
label_true = compile_data.get_label()
|
||||
label_done = compile_data.get_label()
|
||||
|
||||
# compile condition
|
||||
compile_ast(node.child_condition, compile_data)
|
||||
compile_data.code += f' testq %rax, %rax\n'
|
||||
compile_data.code += f' jnz {label_true}\n'
|
||||
|
||||
# compile false statements
|
||||
if node.nodetype == 'unless_expression':
|
||||
compile_ast(node.child_expression_false, compile_data)
|
||||
elif node.nodetype == 'do_unless':
|
||||
for statement in node.children_statements_false:
|
||||
compile_ast(statement, compile_data)
|
||||
else: assert False
|
||||
compile_data.code += f' jmp {label_done}\n'
|
||||
|
||||
# compile true statements
|
||||
compile_data.insert_label(label_true)
|
||||
if node.nodetype == 'unless_expression':
|
||||
compile_ast(node.child_expression_true, compile_data)
|
||||
elif node.nodetype == 'do_unless':
|
||||
for statement in node.children_statements_true:
|
||||
compile_ast(statement, compile_data)
|
||||
else: assert False
|
||||
|
||||
# add label for done
|
||||
compile_data.insert_label(label_done)
|
||||
case 'print':
|
||||
for i, item in enumerate(node.children_items):
|
||||
assert item.type in ['int', 'string', 'date']
|
||||
compile_ast(item, compile_data)
|
||||
compile_print_literal(item.type, compile_data)
|
||||
|
||||
# Print space if there are more items
|
||||
if i < len(node.children_items) - 1:
|
||||
compile_data.code += f' movl $\' \', %edi\n'
|
||||
compile_data.code += f' call putchar\n'
|
||||
|
||||
# Print newline
|
||||
compile_data.code += f' movl $\'\\n\', %edi\n'
|
||||
compile_data.code += f' call putchar\n'
|
||||
case _:
|
||||
print_todo(f'Execute type \'{node.nodetype}\'', node)
|
||||
print_todo(f'Compile type \'{node.nodetype}\'', node)
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
|
@ -293,8 +619,11 @@ if __name__ == '__main__':
|
|||
print(' Name: Oskari Alaranta')
|
||||
else:
|
||||
ast = syntax_check_file(args.file, args.debug)
|
||||
#tree_print.treeprint(ast, 'unicode')
|
||||
|
||||
sem_data = SemData()
|
||||
semantic_check(ast, sem_data)
|
||||
execute_ast(ast, sem_data)
|
||||
#tree_print.treeprint(ast, 'unicode')
|
||||
|
||||
compile_data = CompileData(sem_data)
|
||||
compile_ast(ast, compile_data)
|
||||
print(compile_data.get_full_code())
|
||||
|
|
Loading…
Reference in New Issue