diff --git a/asm/amd64/amd64.go b/asm/amd64/amd64.go index 7849e2c..6881e9f 100644 --- a/asm/amd64/amd64.go +++ b/asm/amd64/amd64.go @@ -6,7 +6,16 @@ import ( ) type Program struct { - Functions []Function + Functions []Function + MainFunction *Function +} + +func (p *Program) executableAsmHeader() string { + + if p.MainFunction.HasReturnValue { + return executableAsmHeader + } + return executableAsmHeaderNoReturnValue } // This calls the main function and uses it's return value to exit @@ -19,9 +28,18 @@ const executableAsmHeader = "format ELF64 executable\n" + " mov rax, 60\n" + " syscall\n" +const executableAsmHeaderNoReturnValue = "format ELF64 executable\n" + + "segment readable executable\n" + + "entry _start\n" + + "_start:\n" + + " call main\n" + + " mov rdi, 0\n" + + " mov rax, 60\n" + + " syscall\n" + func (p *Program) Emit() string { var builder strings.Builder - builder.WriteString(executableAsmHeader) + builder.WriteString(p.executableAsmHeader()) for _, function := range p.Functions { builder.WriteString(function.Emit()) @@ -32,9 +50,10 @@ func (p *Program) Emit() string { } type Function struct { - StackOffset int64 - Name string - Instructions []Instruction + StackOffset int64 + Name string + HasReturnValue bool + Instructions []Instruction } func (f *Function) Emit() string { diff --git a/asm/amd64/codegen.go b/asm/amd64/codegen.go index 11cfd63..117733c 100644 --- a/asm/amd64/codegen.go +++ b/asm/amd64/codegen.go @@ -32,6 +32,12 @@ func CgProgram(prog *ttir.Program) *Program { newProgram = replacePseudo(newProgram) newProgram = instructionFixup(newProgram) + for i, f := range newProgram.Functions { + if f.Name == "main" { + newProgram.MainFunction = &newProgram.Functions[i] + } + } + return &newProgram } @@ -43,23 +49,28 @@ func cgFunction(f ttir.Function) Function { } return Function{ - Name: f.Name, - Instructions: newInstructions, + Name: f.Name, + Instructions: newInstructions, + HasReturnValue: f.HasReturnValue, } } func cgInstruction(i ttir.Instruction) []Instruction { switch i := i.(type) { case *ttir.Ret: - return []Instruction{ - &SimpleInstruction{ - Opcode: Mov, - Lhs: AX, - Rhs: toAsmOperand(i.Op), - }, - &SimpleInstruction{ - Opcode: Ret, - }, + if i.Op != nil { + return []Instruction{ + &SimpleInstruction{ + Opcode: Mov, + Lhs: AX, + Rhs: toAsmOperand(i.Op), + }, + &SimpleInstruction{ + Opcode: Ret, + }, + } + } else { + return []Instruction{&SimpleInstruction{Opcode: Ret}} } case *ttir.Binary: return cgBinary(i) @@ -149,7 +160,7 @@ func rpFunction(f Function) Function { newInstructions = append(newInstructions, rpInstruction(i, r)) } - return Function{Instructions: newInstructions, Name: f.Name, StackOffset: r.currentOffset} + return Function{Instructions: newInstructions, Name: f.Name, StackOffset: r.currentOffset, HasReturnValue: f.HasReturnValue} } func rpInstruction(i Instruction, r *replacePseudoPass) Instruction { @@ -209,7 +220,7 @@ func fixupFunction(f Function) Function { newInstructions = append(newInstructions, fixupInstruction(i)...) } - return Function{Name: f.Name, Instructions: newInstructions, StackOffset: f.StackOffset} + return Function{Name: f.Name, Instructions: newInstructions, StackOffset: f.StackOffset, HasReturnValue: f.HasReturnValue} } func fixupInstruction(i Instruction) []Instruction { diff --git a/tast/tast.go b/tast/tast.go index 6cb2451..af615be 100644 --- a/tast/tast.go +++ b/tast/tast.go @@ -111,3 +111,32 @@ func (be *BinaryExpression) TokenLiteral() string { return be.Token.Literal } func (be *BinaryExpression) String() string { return fmt.Sprintf("(%s %s %s :> %s)", be.Lhs, be.Operator.SymbolString(), be.Rhs, be.ResultType.Name()) } + +type BlockExpression struct { + Token token.Token // The '{' + Expressions []Expression + ReturnExpression Expression // A expression that does not end with a semicolon, there can only be one of those and it hast to be at the end + ReturnType types.Type +} + +func (be *BlockExpression) expressionNode() {} +func (be *BlockExpression) Type() types.Type { + return be.ReturnType +} +func (be *BlockExpression) TokenLiteral() string { return be.Token.Literal } +func (be *BlockExpression) String() string { + var builder strings.Builder + + builder.WriteString("({\n") + for _, expr := range be.Expressions { + builder.WriteString("\t") + builder.WriteString(expr.String()) + builder.WriteString(";\n") + } + if be.ReturnExpression != nil { + builder.WriteString(fmt.Sprintf("\t%s\n", be.ReturnExpression.String())) + } + builder.WriteString("})") + + return builder.String() +} diff --git a/ttir/emit.go b/ttir/emit.go index 09724c9..336353e 100644 --- a/ttir/emit.go +++ b/ttir/emit.go @@ -4,6 +4,7 @@ import ( "fmt" "robaertschi.xyz/robaertschi/tt/tast" + "robaertschi.xyz/robaertschi/tt/types" ) var uniqueId int64 @@ -31,8 +32,9 @@ func emitFunction(function *tast.FunctionDeclaration) *Function { value, instructions := emitExpression(function.Body) instructions = append(instructions, &Ret{Op: value}) return &Function{ - Name: function.Name, - Instructions: instructions, + Name: function.Name, + Instructions: instructions, + HasReturnValue: !function.ReturnType.IsSameType(types.Unit), } } @@ -56,6 +58,22 @@ func emitExpression(expr tast.Expression) (Operand, []Instruction) { instructions = append(instructions, &Binary{Operator: expr.Operator, Lhs: lhsDst, Rhs: rhsDst, Dst: dst}) return dst, instructions } + case *tast.BlockExpression: + instructions := []Instruction{} + + for _, expr := range expr.Expressions { + _, insts := emitExpression(expr) + instructions = append(instructions, insts...) + } + + var value Operand + if expr.ReturnExpression != nil { + dst, insts := emitExpression(expr.ReturnExpression) + value = dst + instructions = append(instructions, insts...) + } + + return value, instructions } panic("unhandled tast.Expression case in ir emitter") } diff --git a/ttir/ttir.go b/ttir/ttir.go index 360baa7..b93356f 100644 --- a/ttir/ttir.go +++ b/ttir/ttir.go @@ -20,8 +20,9 @@ func (p *Program) String() string { } type Function struct { - Name string - Instructions []Instruction + Name string + Instructions []Instruction + HasReturnValue bool } func (f *Function) String() string { @@ -40,11 +41,16 @@ type Instruction interface { } type Ret struct { + // Nullable, if it does not return anything Op Operand } func (r *Ret) String() string { - return fmt.Sprintf("ret %s\n", r.Op) + if r.Op != nil { + return fmt.Sprintf("ret %s\n", r.Op) + } else { + return "ret\n" + } } func (r *Ret) instruction() {} diff --git a/typechecker/check.go b/typechecker/check.go index 2fba81d..13f29c4 100644 --- a/typechecker/check.go +++ b/typechecker/check.go @@ -85,6 +85,16 @@ func (c *Checker) checkExpression(expr tast.Expression) error { } return errors.Join(lhsErr, rhsErr, operandErr) + case *tast.BlockExpression: + errs := []error{} + + for _, expr := range expr.Expressions { + errs = append(errs, c.checkExpression(expr)) + } + if expr.ReturnExpression != nil { + errs = append(errs, c.checkExpression(expr.ReturnExpression)) + } + return errors.Join(errs...) } - return fmt.Errorf("unhandled expression in type checker") + return fmt.Errorf("unhandled expression %T in type checker", expr) } diff --git a/typechecker/infer.go b/typechecker/infer.go index edb9a7d..c1e61c5 100644 --- a/typechecker/infer.go +++ b/typechecker/infer.go @@ -60,6 +60,39 @@ func (c *Checker) inferExpression(expr ast.Expression) (tast.Expression, error) } return &tast.BinaryExpression{Lhs: lhs, Rhs: rhs, Operator: expr.Operator, Token: expr.Token, ResultType: resultType}, errors.Join(lhsErr, rhsErr) + case *ast.BlockExpression: + expressions := []tast.Expression{} + errs := []error{} + + for _, expr := range expr.Expressions { + newExpr, err := c.inferExpression(expr) + if err != nil { + errs = append(errs, err) + } else { + expressions = append(expressions, newExpr) + } + } + + var returnExpr tast.Expression + var returnType types.Type + if expr.ReturnExpression != nil { + expr, err := c.inferExpression(expr.ReturnExpression) + returnExpr = expr + if err != nil { + errs = append(errs, err) + } else { + returnType = returnExpr.Type() + } + } else { + returnType = types.Unit + } + + return &tast.BlockExpression{ + Token: expr.Token, + Expressions: expressions, + ReturnType: returnType, + ReturnExpression: returnExpr, + }, errors.Join(errs...) } return nil, fmt.Errorf("unhandled expression in type inferer") } diff --git a/types/types.go b/types/types.go index 322314a..fd0ebcf 100644 --- a/types/types.go +++ b/types/types.go @@ -15,11 +15,13 @@ type TypeId struct { } const ( - I64Id int64 = iota + UnitId int64 = iota + I64Id BoolId ) var ( + Unit = New(UnitId, "()") I64 = New(I64Id, "i64") Bool = New(BoolId, "bool") )