Fix bug with temporary registers and optimize binary_op code gen

This commit is contained in:
Bananymous 2024-04-29 21:03:39 +03:00
parent 69ae83efcb
commit 351634e512
1 changed files with 43 additions and 23 deletions

View File

@ -448,16 +448,36 @@ class CompileData:
# Add function and procedure definitions
for name, code in self.callables.items():
if name == 'main':
code_str += '.global main\n'
saved_registers = []
else:
caller_saved = ['%r8', '%r9', '%r10', '%r11']
saved_registers = set()
for instruction in code:
for reg in caller_saved:
if reg in instruction.operands:
saved_registers.add(reg)
break
saved_registers = list(saved_registers)
code_str += name + ':\n'
code_str += ' pushq %rbp\n'
code_str += ' movq %rsp, %rbp\n'
for reg in saved_registers:
code_str += f' pushq {reg}\n'
if len(saved_registers) % 2 != 0:
code_str += ' subq $8, %rsp\n'
for instruction in code:
if instruction.opcode == '<label>':
code_str += f'{instruction.operands[0]}:\n'
else:
code_str += f' {instruction}\n'
if len(saved_registers) % 2 != 0:
code_str += ' addq $8, %rsp\n'
for reg in reversed(saved_registers):
code_str += f' popq {reg}\n'
code_str += ' leave\n'
code_str += ' ret\n'
code_str += '\n'
@ -555,30 +575,30 @@ def compile_ast(node: ASTnode, compile_data: CompileData) -> None:
compile_data.code = old_code
# check if we can use temporary registers instead of stack
temp_registers = ['%rcx', '%r8', '%r9', '%r10', '%r11']
temp_reg = None
for reg in temp_registers:
usable_registers = ['%r8', '%r9', '%r10', '%r11']
register = '%rcx'
for reg in usable_registers:
valid = True
for instruction in lhs_code:
if reg in instruction.operands:
valid = False
break
if valid:
temp_reg = reg
register = reg
break
# check if lhs uses call, this determines whether we need to align stack
lhs_call = False
align_stack = False
for instruction in lhs_code:
if instruction.opcode == 'call':
lhs_call = True
align_stack = True
break
# Add code for RHS calculation
compile_data.code += rhs_code
if temp_reg is not None:
compile_data.code.append(Instruction('movq', ['%rax', temp_reg]))
elif not lhs_call:
if register != '%rcx':
compile_data.code.append(Instruction('movq', ['%rax', register]))
elif not align_stack:
compile_data.code.append(Instruction('pushq', ['%rax']))
else:
compile_data.code.append(Instruction('subq', ['$16', '%rsp']))
@ -586,34 +606,34 @@ def compile_ast(node: ASTnode, compile_data: CompileData) -> None:
# Add code for LHS calculation
compile_data.code += lhs_code
if temp_reg is not None:
compile_data.code.append(Instruction('movq', [temp_reg, '%rcx']))
elif not lhs_call:
compile_data.code.append(Instruction('popq', ['%rcx']))
if register != '%rcx':
pass
elif not align_stack:
compile_data.code.append(Instruction('popq', [register]))
else:
compile_data.code.append(Instruction('movq', ['0(%rsp)', '%rcx']))
compile_data.code.append(Instruction('movq', ['0(%rsp)', register]))
compile_data.code.append(Instruction('addq', ['$16', '%rsp']))
# If we are adding or subtracting dates with integers, multiply the integer by number of seconds in a day
if node.child_lhs.type == 'date' and node.child_rhs.type == 'int':
compile_data.code.append(Instruction('imulq', ['$86400', '%rcx']))
compile_data.code.append(Instruction('imulq', ['$86400', register]))
# perform operation
if node.value == '+':
compile_data.code.append(Instruction('addq', ['%rcx', '%rax']))
compile_data.code.append(Instruction('addq', [register, '%rax']))
elif node.value == '-':
compile_data.code.append(Instruction('subq', ['%rcx', '%rax']))
compile_data.code.append(Instruction('subq', [register, '%rax']))
elif node.value == '*':
compile_data.code.append(Instruction('imulq', ['%rcx', '%rax']))
compile_data.code.append(Instruction('imulq', [register, '%rax']))
elif node.value == '/':
compile_data.code.append(Instruction('cqo'))
compile_data.code.append(Instruction('idivq', ['%rcx']))
compile_data.code.append(Instruction('idivq', [register]))
elif node.value == '<':
compile_data.code.append(Instruction('cmpq', ['%rcx', '%rax']))
compile_data.code.append(Instruction('cmpq', [register, '%rax']))
compile_data.code.append(Instruction('setl', ['%al']))
compile_data.code.append(Instruction('movzbq', ['%al', '%rax']))
elif node.value == '=':
compile_data.code.append(Instruction('cmpq', ['%rcx', '%rax']))
compile_data.code.append(Instruction('cmpq', [register, '%rax']))
compile_data.code.append(Instruction('sete', ['%al']))
compile_data.code.append(Instruction('movzbq', ['%al', '%rax']))
else: assert False
@ -621,9 +641,9 @@ def compile_ast(node: ASTnode, compile_data: CompileData) -> None:
# if both operands are dates, divide result by number of seconds in a day
if node.child_lhs.type == 'date' and node.child_rhs.type == 'date':
assert node.value == '-'
compile_data.code.append(Instruction('movq', ['$86400', '%rcx']))
compile_data.code.append(Instruction('movq', ['$86400', register]))
compile_data.code.append(Instruction('cqo'))
compile_data.code.append(Instruction('idivq', ['%rcx']))
compile_data.code.append(Instruction('idivq', [register]))
case 'function_call' | 'procedure_call':
if node.value == 'Today':
compile_data.code.append(Instruction('call', ['__builtin_today']))