Clean up code and add arguments for compiling code

This commit is contained in:
Bananymous 2024-04-28 03:02:25 +03:00
parent 5c5942629d
commit 8656d74139
1 changed file with 54 additions and 34 deletions

View File

@ -4,6 +4,7 @@ import argparse
from calendar import timegm from calendar import timegm
from copy import deepcopy from copy import deepcopy
from datetime import date, timedelta from datetime import date, timedelta
import subprocess
import tree_print import tree_print
from build_ast import ASTnode, syntax_check_file from build_ast import ASTnode, syntax_check_file
@ -278,6 +279,7 @@ def semantic_check(node: ASTnode, sem_data: SemData) -> None | ASTnode:
class CompileData: class CompileData:
def __init__(self, sem_data: SemData): def __init__(self, sem_data: SemData):
self.sem_data = sem_data self.sem_data = sem_data
self.date_buffer_size = 128
self.string_literals = [] self.string_literals = []
self.label_counter = 0 self.label_counter = 0
self.callables = {} self.callables = {}
@ -299,26 +301,26 @@ class CompileData:
ret ret
''' '''
self.callables['__builtin_print_date'] = '''\ self.callables['__builtin_print_date'] = f'''\
pushq %rbp pushq %rbp
movq %rsp, %rbp movq %rsp, %rbp
subq $16, %rsp subq $16, %rsp
movq %rdi, 0(%rsp) movq %rdi, 0(%rsp)
leaq 0(%rsp), %rdi leaq 0(%rsp), %rdi
call localtime call localtime
movq $date_buffer, %rdi movq $.date_buffer, %rdi
movq $(date_buffer_end - date_buffer), %rsi movq ${self.date_buffer_size}, %rsi
movq $date_format, %rdx movq $.date_format, %rdx
movq %rax, %rcx movq %rax, %rcx
call strftime call strftime
movq $str_format, %rdi movq $.str_format, %rdi
movq $date_buffer, %rsi movq $.date_buffer, %rsi
call printf call printf
leave leave
ret ret
''' '''
self.callables['__builtin_get_day_attr'] = '''\ self.callables['__builtin_get_day_attr'] = f'''\
pushq %rbp pushq %rbp
movq %rsp, %rbp movq %rsp, %rbp
subq $16, %rsp subq $16, %rsp
@ -326,12 +328,12 @@ class CompileData:
movq %rsi, 8(%rsp) movq %rsi, 8(%rsp)
leaq 0(%rsp), %rdi leaq 0(%rsp), %rdi
call localtime call localtime
movq $date_buffer, %rdi movq $.date_buffer, %rdi
movq $(date_buffer_end - date_buffer), %rsi movq ${self.date_buffer_size}, %rsi
movq 8(%rsp), %rdx movq 8(%rsp), %rdx
movq %rax, %rcx movq %rax, %rcx
call strftime call strftime
movq $date_buffer, %rdi movq $.date_buffer, %rdi
call atoi call atoi
leave leave
ret ret
@ -363,31 +365,30 @@ class CompileData:
return f'-{offset}(%rbp)' return f'-{offset}(%rbp)'
if symbol in self.sem_data.global_symbol_table: if symbol in self.sem_data.global_symbol_table:
offset = 8 * list(self.sem_data.global_symbol_table.keys()).index(symbol) offset = 8 * list(self.sem_data.global_symbol_table.keys()).index(symbol)
return f'(globals + {offset})' return f'(.globals + {offset})'
assert False assert False
def get_full_code(self) -> str: def get_full_code(self) -> str:
# Data section with string literals # Data section with string literals
prefix = '.section .data\n' prefix = '.section .data\n'
prefix += 'int_format: .asciz "%lld"\n' prefix += '.int_format: .asciz "%lld"\n'
prefix += 'str_format: .asciz "%s"\n' prefix += '.str_format: .asciz "%s"\n'
prefix += 'date_format: .asciz "%Y-%m-%d"\n' prefix += '.date_format: .asciz "%Y-%m-%d"\n'
prefix += 'day_format: .asciz "%d"\n' prefix += '.day_format: .asciz "%d"\n'
prefix += 'month_format: .asciz "%m"\n' prefix += '.month_format: .asciz "%m"\n'
prefix += 'year_format: .asciz "%Y"\n' prefix += '.year_format: .asciz "%Y"\n'
prefix += 'weekday_format: .asciz "%u"\n' prefix += '.weekday_format: .asciz "%u"\n'
prefix += 'weeknum_format: .asciz "%W"\n' prefix += '.weeknum_format: .asciz "%W"\n'
for index, string in enumerate(self.string_literals): for index, string in enumerate(self.string_literals):
prefix += f'S{index}: .asciz "{string}"\n' prefix += f'S{index}: .asciz "{string}"\n'
prefix += '\n' prefix += '\n'
# BSS section for uninitialized data # BSS section for uninitialized data
prefix += '.section .bss\n' prefix += f'.section .bss\n'
prefix += 'date_buffer:\n' prefix += f'.date_buffer:\n'
prefix += ' .skip 128\n' prefix += f' .skip {self.date_buffer_size}\n'
prefix += 'date_buffer_end:\n'
if len(self.sem_data.global_symbol_table) != 0: if len(self.sem_data.global_symbol_table) != 0:
prefix += 'globals:\n' prefix += '.globals:\n'
prefix += f' .skip {len(sem_data.global_symbol_table) * 8}\n' prefix += f' .skip {len(sem_data.global_symbol_table) * 8}\n'
prefix += '\n' prefix += '\n'
@ -414,11 +415,11 @@ class CompileData:
def compile_print_literal(print_type: str, compile_data: CompileData) -> None: def compile_print_literal(print_type: str, compile_data: CompileData) -> None:
if print_type == 'int': if print_type == 'int':
compile_data.code += f' movq $int_format, %rdi\n' compile_data.code += f' movq $.int_format, %rdi\n'
compile_data.code += f' movq %rax, %rsi\n' compile_data.code += f' movq %rax, %rsi\n'
compile_data.code += f' call printf\n' compile_data.code += f' call printf\n'
elif print_type == 'string': elif print_type == 'string':
compile_data.code += f' movq $str_format, %rdi\n' compile_data.code += f' movq $.str_format, %rdi\n'
compile_data.code += f' movq %rax, %rsi\n' compile_data.code += f' movq %rax, %rsi\n'
compile_data.code += f' call printf\n' compile_data.code += f' call printf\n'
elif print_type == 'date': elif print_type == 'date':
@ -465,9 +466,9 @@ def compile_ast(node: ASTnode, compile_data: CompileData) -> None:
compile_data.scope = None compile_data.scope = None
# Initialize global variables # Initialize global variables
for index, (name, variable) in enumerate(compile_data.sem_data.global_symbol_table.items()): for index, (name, variable) in enumerate(compile_data.sem_data.global_symbol_table.items()):
offset = 8 * index address = compile_data.symbol_address(name)
compile_ast(variable, compile_data) compile_ast(variable, compile_data)
compile_data.code += f' movq %rax, (globals + {offset})\n' compile_data.code += f' movq %rax, {address}\n'
# Compile program statements # Compile program statements
for statement in node.children_statements: for statement in node.children_statements:
compile_ast(statement, compile_data) compile_ast(statement, compile_data)
@ -477,9 +478,13 @@ def compile_ast(node: ASTnode, compile_data: CompileData) -> None:
address = compile_data.symbol_address(node.value) address = compile_data.symbol_address(node.value)
compile_data.code += f' movq {address}, %rax\n' compile_data.code += f' movq {address}, %rax\n'
case 'assignment': case 'assignment':
if node.child_lhs.nodetype == 'attribute_write':
print_todo('Attribute write', node)
elif node.child_lhs.nodetype == 'identifier':
address = compile_data.symbol_address(node.child_lhs.value) address = compile_data.symbol_address(node.child_lhs.value)
compile_ast(node.child_rhs, compile_data) compile_ast(node.child_rhs, compile_data)
compile_data.code += f' movq %rax, {address}\n' compile_data.code += f' movq %rax, {address}\n'
else: assert False
case 'binary_op': case 'binary_op':
assert node.value in ['+', '-', '*', '/', '<', '='] assert node.value in ['+', '-', '*', '/', '<', '=']
@ -573,7 +578,7 @@ def compile_ast(node: ASTnode, compile_data: CompileData) -> None:
case 'attribute_read': case 'attribute_read':
compile_ast(node.child_identifier, compile_data) compile_ast(node.child_identifier, compile_data)
compile_data.code += f' movq %rax, %rdi\n' compile_data.code += f' movq %rax, %rdi\n'
compile_data.code += f' movq ${node.child_attribute.value}_format, %rsi\n' compile_data.code += f' movq $.{node.child_attribute.value}_format, %rsi\n'
compile_data.code += f' call __builtin_get_day_attr\n' compile_data.code += f' call __builtin_get_day_attr\n'
case 'do_until': case 'do_until':
label_loop = compile_data.get_label() label_loop = compile_data.get_label()
@ -639,6 +644,8 @@ if __name__ == '__main__':
group = parser.add_mutually_exclusive_group(required=True) group = parser.add_mutually_exclusive_group(required=True)
group.add_argument('--who', action='store_true', help='print out student IDs and NAMEs of authors') group.add_argument('--who', action='store_true', help='print out student IDs and NAMEs of authors')
group.add_argument('-f', '--file', help='filename to process') group.add_argument('-f', '--file', help='filename to process')
parser.add_argument('-o', '--compile', help='output filename for compiled code. if not given, no compilation is done')
parser.add_argument('-a', '--assembly', help='output filename for generated assembly code')
args = parser.parse_args() args = parser.parse_args()
@ -651,8 +658,21 @@ if __name__ == '__main__':
sem_data = SemData() sem_data = SemData()
semantic_check(ast, sem_data) semantic_check(ast, sem_data)
#tree_print.treeprint(ast, 'unicode')
if args.debug:
tree_print.treeprint(ast, 'unicode')
compile_data = CompileData(sem_data) compile_data = CompileData(sem_data)
compile_ast(ast, compile_data) compile_ast(ast, compile_data)
print(compile_data.get_full_code())
assembly = compile_data.get_full_code()
if args.assembly is None:
if args.compile is None:
print(assembly)
else:
with open(args.assembly, 'w', encoding='utf-8') as file:
file.write(assembly)
if args.compile is not None:
subprocess.run(['gcc', '-x', 'assembler', '-o', args.compile, '-static', '-'], input=assembly, encoding='utf-8')