diff --git a/asm/amd64/amd64.go b/asm/amd64/amd64.go index a05389c..1a9d2df 100644 --- a/asm/amd64/amd64.go +++ b/asm/amd64/amd64.go @@ -84,12 +84,11 @@ type Opcode string const ( // Two operands - Mov Opcode = "mov" // Lhs: dst, Rhs: src, or better said intel syntax - Add Opcode = "add" - Sub Opcode = "sub" - Imul Opcode = "imul" - Cmp Opcode = "cmp" - SetCC Opcode = "setcc" + Mov Opcode = "mov" // Lhs: dst, Rhs: src, or better said intel syntax + Add Opcode = "add" + Sub Opcode = "sub" + Imul Opcode = "imul" + Cmp Opcode = "cmp" // One operand Idiv Opcode = "idiv" @@ -130,6 +129,27 @@ func (i *SimpleInstruction) InstructionString() string { return fmt.Sprintf("%s %s, %s", i.Opcode, i.Lhs.OperandString(Eight), i.Rhs.OperandString(Eight)) } +type Label string + +func (l Label) InstructionString() string { + return fmt.Sprintf("%s:", l) +} + +type JumpCCInstruction struct { + Cond CondCode + Dst string +} + +func (j *JumpCCInstruction) InstructionString() string { + return fmt.Sprintf("j%s %s", j.Cond, j.Dst) +} + +type JmpInstruction string + +func (j JmpInstruction) InstructionString() string { + return fmt.Sprintf("jmp %s", j) +} + type SetCCInstruction struct { Cond CondCode Dst Operand diff --git a/asm/amd64/codegen.go b/asm/amd64/codegen.go index fd4285e..f0474cc 100644 --- a/asm/amd64/codegen.go +++ b/asm/amd64/codegen.go @@ -74,6 +74,36 @@ func cgInstruction(i ttir.Instruction) []Instruction { } case *ttir.Binary: return cgBinary(i) + case ttir.Label: + return []Instruction{Label(i)} + case *ttir.JumpIfZero: + return []Instruction{ + &SimpleInstruction{ + Opcode: Cmp, + Lhs: toAsmOperand(i.Value), + Rhs: Imm(0), + }, + &JumpCCInstruction{ + Cond: Equal, + Dst: i.Label, + }, + } + case *ttir.JumpIfNotZero: + return []Instruction{ + &SimpleInstruction{ + Opcode: Cmp, + Lhs: toAsmOperand(i.Value), + Rhs: Imm(0), + }, + &JumpCCInstruction{ + Cond: NotEqual, + Dst: i.Label, + }, + } + case ttir.Jump: + return []Instruction{JmpInstruction(i)} + case *ttir.Copy: + return []Instruction{&SimpleInstruction{Opcode: Mov, Lhs: toAsmOperand(i.Dst), Rhs: toAsmOperand(i.Src)}} } return []Instruction{} @@ -190,6 +220,8 @@ func rpInstruction(i Instruction, r *replacePseudoPass) Instruction { Cond: i.Cond, Dst: pseudoToStack(i.Dst, r), } + case *JumpCCInstruction, JmpInstruction, Label: + return i } panic("invalid instruction") @@ -295,6 +327,8 @@ func fixupInstruction(i Instruction) []Instruction { return []Instruction{i} case *SetCCInstruction: + return []Instruction{i} + case *JumpCCInstruction, JmpInstruction, Label: return []Instruction{i} } diff --git a/asm/qbe/qbe.go b/asm/qbe/qbe.go index 57696e1..74393f8 100644 --- a/asm/qbe/qbe.go +++ b/asm/qbe/qbe.go @@ -10,6 +10,13 @@ import ( _ "embed" ) +var extraLabelId int64 = 0 + +func extraLabel() string { + extraLabelId += 1 + return fmt.Sprintf("qbe.extra.%d", extraLabelId) +} + //go:embed qbe_stub.asm var Stub string @@ -118,7 +125,30 @@ func emitInstruction(w io.Writer, i ttir.Instruction) error { if err := emitf(w, "\t%s =l %s %s, %s\n", emitOperand(i.Dst), inst, emitOperand(i.Lhs), emitOperand(i.Rhs)); err != nil { return err } - + case *ttir.Copy: + if err := emitf(w, "\t%s =l copy %s\n", emitOperand(i.Dst), emitOperand(i.Src)); err != nil { + return err + } + case ttir.Label: + if err := emitf(w, "@%s\n", string(i)); err != nil { + return err + } + case ttir.Jump: + if err := emitf(w, "\tjmp @%s\n", string(i)); err != nil { + return err + } + case *ttir.JumpIfNotZero: + after := extraLabel() + if err := emitf(w, "\tjnz %s, @%s, @%s\n@%s\n", emitOperand(i.Value), i.Label, after, after); err != nil { + return err + } + case *ttir.JumpIfZero: + after := extraLabel() + if err := emitf(w, "\tjnz %s, @%s, @%s\n@%s\n", emitOperand(i.Value), after, i.Label, after); err != nil { + return err + } + default: + panic("unkown instruction") } return nil diff --git a/ast/ast.go b/ast/ast.go index 01cdc5d..1d03240 100644 --- a/ast/ast.go +++ b/ast/ast.go @@ -165,3 +165,28 @@ func (be *BlockExpression) String() string { return builder.String() } + +type IfExpression struct { + Token token.Token // The 'if' token + Condition Expression + Then Expression + // Can be nil + Else Expression +} + +func (ie *IfExpression) expressionNode() {} +func (ie *IfExpression) TokenLiteral() string { return ie.Token.Literal } +func (ie *IfExpression) String() string { + var builder strings.Builder + + builder.WriteString("(if\n\t") + builder.WriteString(ie.Then.String()) + + if ie.Else != nil { + builder.WriteString(" else in ") + builder.WriteString(ie.Else.String()) + } + builder.WriteString(")") + + return builder.String() +} diff --git a/parser/parser.go b/parser/parser.go index 02955a1..d5ef3f5 100644 --- a/parser/parser.go +++ b/parser/parser.go @@ -56,6 +56,7 @@ func New(l *lexer.Lexer) *Parser { p.registerPrefixFn(token.False, p.parseBooleanExpression) p.registerPrefixFn(token.OpenParen, p.parseGroupedExpression) p.registerPrefixFn(token.OpenBrack, p.parseBlockExpression) + p.registerPrefixFn(token.If, p.parseIfExpression) p.infixParseFns = make(map[token.TokenType]infixParseFn) p.registerInfixFn(token.Plus, p.parseBinaryExpression) @@ -334,3 +335,36 @@ func (p *Parser) parseBlockExpression() ast.Expression { return block } + +func (p *Parser) parseIfExpression() ast.Expression { + if ok, errExpr := p.expect(token.If); !ok { + return errExpr + } + + ifExpr := &ast.IfExpression{Token: p.curToken} + + p.nextToken() + ifExpr.Condition = p.parseExpression(PrecLowest) + + if p.peekTokenIs(token.OpenBrack) { + p.nextToken() + ifExpr.Then = p.parseBlockExpression() + } else { + if ok, errExpr := p.expectPeek(token.In); !ok { + return errExpr + } + + p.nextToken() + ifExpr.Then = p.parseExpression(PrecLowest) + } + + if p.peekTokenIs(token.Else) { + p.nextToken() + p.nextToken() + ifExpr.Else = p.parseExpression(PrecLowest) + } else { + ifExpr.Else = nil + } + + return ifExpr +} diff --git a/tast/tast.go b/tast/tast.go index af615be..7cbefc2 100644 --- a/tast/tast.go +++ b/tast/tast.go @@ -140,3 +140,32 @@ func (be *BlockExpression) String() string { return builder.String() } + +type IfExpression struct { + Token token.Token // The 'if' token + Condition Expression + Then Expression + // Can be nil + Else Expression + ReturnType types.Type +} + +func (ie *IfExpression) expressionNode() {} +func (ie *IfExpression) Type() types.Type { + return ie.ReturnType +} +func (ie *IfExpression) TokenLiteral() string { return ie.Token.Literal } +func (ie *IfExpression) String() string { + var builder strings.Builder + + builder.WriteString("(if\n\t") + builder.WriteString(ie.Then.String()) + + if ie.Else != nil { + builder.WriteString(" else in ") + builder.WriteString(ie.Else.String()) + } + builder.WriteString(")") + + return builder.String() +} diff --git a/test.tt b/test.tt index 7849cdf..0d87ac0 100644 --- a/test.tt +++ b/test.tt @@ -1,3 +1,5 @@ fn main() = { - 5 == 3 + if 3 == 3 + in 4 + else 3 }; diff --git a/token/token.go b/token/token.go index 0df8713..a09aaf5 100644 --- a/token/token.go +++ b/token/token.go @@ -16,9 +16,12 @@ type Token struct { } var keywords = map[string]TokenType{ - "fn": Fn, - "true": True, + "else": Else, "false": False, + "fn": Fn, + "if": If, + "in": In, + "true": True, } const ( @@ -48,9 +51,12 @@ const ( GreaterThanEqual TokenType = ">=" // Keywords - Fn TokenType = "FN" - True TokenType = "TRUE" + Else TokenType = "ELSE" False TokenType = "FALSE" + Fn TokenType = "FN" + If TokenType = "IF" + In TokenType = "IN" + True TokenType = "TRUE" ) func LookupKeyword(literal string) TokenType { diff --git a/ttir/emit.go b/ttir/emit.go index 5d8990c..e642859 100644 --- a/ttir/emit.go +++ b/ttir/emit.go @@ -7,11 +7,18 @@ import ( "robaertschi.xyz/robaertschi/tt/types" ) -var uniqueId int64 +var uniqueTempId int64 func temp() string { - uniqueId += 1 - return fmt.Sprintf("temp.%d", uniqueId) + uniqueTempId += 1 + return fmt.Sprintf("temp.%d", uniqueTempId) +} + +var uniqueLabelId int64 + +func tempLabel() string { + uniqueLabelId += 1 + return fmt.Sprintf("lbl.%d", uniqueLabelId) } func EmitProgram(program *tast.Program) *Program { @@ -82,6 +89,32 @@ func emitExpression(expr tast.Expression) (Operand, []Instruction) { } return value, instructions + case *tast.IfExpression: + // if (cond -> false jump to "else") { + // ... + // } jump to end of if + // else: else { + // ... + // } endOfIf: + elseLabel := tempLabel() + endOfIfLabel := tempLabel() + dst := &Var{Value: temp()} + + condDst, instructions := emitExpression(expr.Condition) + + instructions = append(instructions, &JumpIfZero{Value: condDst, Label: elseLabel}) + thenDst, thenInstructions := emitExpression(expr.Then) + instructions = append(instructions, thenInstructions...) + instructions = append(instructions, &Copy{Src: thenDst, Dst: dst}, Jump(endOfIfLabel)) + + instructions = append(instructions, Label(elseLabel)) + if expr.Else != nil { + elseDst, elseInstructions := emitExpression(expr.Else) + instructions = append(instructions, elseInstructions...) + instructions = append(instructions, &Copy{Src: elseDst, Dst: dst}) + } + instructions = append(instructions, Label(endOfIfLabel)) + return dst, instructions } panic("unhandled tast.Expression case in ir emitter") } diff --git a/ttir/ttir.go b/ttir/ttir.go index 65a3a43..06a62fc 100644 --- a/ttir/ttir.go +++ b/ttir/ttir.go @@ -67,6 +67,16 @@ func (b *Binary) String() string { } func (b *Binary) instruction() {} +type Copy struct { + Src Operand + Dst Operand +} + +func (c *Copy) String() string { + return fmt.Sprintf("%s = copy %s\n", c.Dst, c.Src) +} +func (c *Copy) instruction() {} + type JumpIfZero struct { Value Operand Label string diff --git a/typechecker/check.go b/typechecker/check.go index 13f29c4..8830525 100644 --- a/typechecker/check.go +++ b/typechecker/check.go @@ -7,6 +7,7 @@ import ( "robaertschi.xyz/robaertschi/tt/ast" "robaertschi.xyz/robaertschi/tt/tast" "robaertschi.xyz/robaertschi/tt/token" + "robaertschi.xyz/robaertschi/tt/types" ) type Checker struct { @@ -95,6 +96,26 @@ func (c *Checker) checkExpression(expr tast.Expression) error { errs = append(errs, c.checkExpression(expr.ReturnExpression)) } return errors.Join(errs...) + case *tast.IfExpression: + condErr := c.checkExpression(expr.Condition) + if condErr == nil { + if !expr.Condition.Type().IsSameType(types.Bool) { + condErr = c.error(expr.Token, "the condition in the if should be a boolean, but got %q", expr.Condition.Type().Name()) + } + } + thenErr := c.checkExpression(expr.Then) + + if expr.Else == nil { + return errors.Join(condErr, thenErr) + } + + elseErr := c.checkExpression(expr.Else) + if thenErr == nil && elseErr == nil { + if !expr.Then.Type().IsSameType(expr.Else.Type()) { + thenErr = c.error(expr.Token, "the then branch of type %q does not match with the else branch of type %q", expr.Then.Type().Name(), expr.Else.Type().Name()) + } + } + return errors.Join(condErr, thenErr, elseErr) } return fmt.Errorf("unhandled expression %T in type checker", expr) } diff --git a/typechecker/infer.go b/typechecker/infer.go index c1e61c5..72068f4 100644 --- a/typechecker/infer.go +++ b/typechecker/infer.go @@ -93,6 +93,17 @@ func (c *Checker) inferExpression(expr ast.Expression) (tast.Expression, error) ReturnType: returnType, ReturnExpression: returnExpr, }, errors.Join(errs...) + + case *ast.IfExpression: + cond, condErr := c.inferExpression(expr.Condition) + then, thenErr := c.inferExpression(expr.Then) + + if expr.Else != nil { + elseExpr, elseErr := c.inferExpression(expr.Else) + + return &tast.IfExpression{Token: expr.Token, Condition: cond, Then: then, Else: elseExpr, ReturnType: then.Type()}, errors.Join(condErr, thenErr, elseErr) + } + return &tast.IfExpression{Token: expr.Token, Condition: cond, Then: then, Else: nil, ReturnType: types.Unit}, errors.Join(condErr, thenErr) } return nil, fmt.Errorf("unhandled expression in type inferer") }