diff --git a/04_semantics_and_running/main.py b/04_semantics_and_running/main.py index f2851e3..067e81c 100644 --- a/04_semantics_and_running/main.py +++ b/04_semantics_and_running/main.py @@ -324,6 +324,58 @@ class CompileData: return f'(.globals + {offset})' assert False + def optimize_assembly(self) -> str: + for name, instructions in self.callables.items(): + if name.startswith('__builtin_'): + continue + + changed = True + while changed: + changed = False + + i = 0 + # Remove redundant movq instructions + while i < len(instructions): + if instructions[i].opcode != 'movq' or instructions[i].operands[0] != instructions[i].operands[1]: + i += 1 + continue + instructions.pop(i) + changed = True + + i = 0 + # Optimize movq to register followed by pushq register + while i < len(instructions): + if instructions[i].opcode != 'movq' or instructions[i + 1].opcode != 'pushq': + i += 1 + continue + # push of 64 bit immediate is not possible + if instructions[i].operands[0][0] == '$' and int(instructions[i].operands[0][1:]) > 0xFFFFFFFF: + continue + instructions[i] = Instruction('pushq', [instructions[i].operands[0]]) + instructions.pop(i + 1) + i -= 1 + changed = True + + i = 0 + # Optimize movq to rax followed by movq from rax + while i < len(instructions): + if instructions[i].opcode != 'movq' or instructions[i + 1].opcode != 'movq': + i += 1 + continue + if instructions[i].operands[1] != '%rax' or instructions[i + 1].operands[0] != '%rax': + i += 1 + continue + if instructions[i].operands[0][0] not in ['$', '%'] and instructions[i + 1].operands[1][0] != '%': + i += 1 + continue + # move of 64 bit immediate to memory is not possible + if instructions[i].operands[0][0] == '$' and instructions[i + 1].operands[0] != '%' and int(instructions[i].operands[0][1:]) > 0xFFFFFFFF: + continue + instructions[i] = Instruction('movq', [instructions[i].operands[0], instructions[i + 1].operands[1]]) + instructions.pop(i + 1) + i -= 1 + changed = True + def add_builtin_functions(self) -> None: today = [] today.append(Instruction('xorq', ['%rdi', '%rdi'])) @@ -710,6 +762,7 @@ if __name__ == '__main__': compile_data = CompileData(sem_data) compile_ast(ast, compile_data) + compile_data.optimize_assembly() assembly = compile_data.get_full_code() if args.assembly is not None: