added equal and not equal, tests are needed for those

This commit is contained in:
Robin Bärtschi 2025-01-23 14:46:42 +01:00
parent d998ecfc42
commit 8e684b800c
12 changed files with 287 additions and 66 deletions

View File

@ -49,15 +49,24 @@ func (f *Function) Emit() string {
return builder.String() return builder.String()
} }
type CondCode string
const (
Equal CondCode = "e"
NotEqual CondCode = "ne"
)
type Opcode string type Opcode string
const ( const (
// Two operands // Two operands
Mov Opcode = "mov" // Lhs: dst, Rhs: src, or better said intel syntax Mov Opcode = "mov" // Lhs: dst, Rhs: src, or better said intel syntax
Add Opcode = "add" Add Opcode = "add"
Sub Opcode = "sub" Sub Opcode = "sub"
Imul Opcode = "imul" Imul Opcode = "imul"
Cmp Opcode = "cmp"
SetCC Opcode = "setcc"
// One operand // One operand
Idiv Opcode = "idiv" Idiv Opcode = "idiv"
@ -67,7 +76,11 @@ const (
Cdq Opcode = "cdq" Cdq Opcode = "cdq"
) )
type Instruction struct { type Instruction interface {
InstructionString() string
}
type SimpleInstruction struct {
Opcode Opcode Opcode Opcode
// Dst // Dst
Lhs Operand Lhs Operand
@ -75,8 +88,7 @@ type Instruction struct {
Rhs Operand Rhs Operand
} }
func (i *Instruction) InstructionString() string { func (i *SimpleInstruction) InstructionString() string {
if i.Opcode == Ret { if i.Opcode == Ret {
return fmt.Sprintf("mov rsp, rbp\n pop rbp\n ret\n") return fmt.Sprintf("mov rsp, rbp\n pop rbp\n ret\n")
} }
@ -95,6 +107,15 @@ func (i *Instruction) InstructionString() string {
return fmt.Sprintf("%s %s, %s", i.Opcode, i.Lhs.OperandString(Eight), i.Rhs.OperandString(Eight)) return fmt.Sprintf("%s %s, %s", i.Opcode, i.Lhs.OperandString(Eight), i.Rhs.OperandString(Eight))
} }
type SetCCInstruction struct {
Cond CondCode
Dst Operand
}
func (si *SetCCInstruction) InstructionString() string {
return fmt.Sprintf("set%s %s", si.Cond, si.Dst.OperandString(Eight))
}
type OperandSize int type OperandSize int
type Operand interface { type Operand interface {

View File

@ -1,6 +1,7 @@
package amd64 package amd64
import ( import (
_ "embed"
"strings" "strings"
"testing" "testing"
@ -21,6 +22,9 @@ func TestOperands(t *testing.T) {
} }
} }
//go:embed basic_test.txt
var basicTest string
func TestCodegen(t *testing.T) { func TestCodegen(t *testing.T) {
program := &ttir.Program{ program := &ttir.Program{
Functions: []ttir.Function{ Functions: []ttir.Function{
@ -55,8 +59,7 @@ func TestCodegen(t *testing.T) {
expectProgram(t, expectedProgram, actualProgram) expectProgram(t, expectedProgram, actualProgram)
actual := actualProgram.Emit() actual := actualProgram.Emit()
expected := basicTest
expected := executableAsmHeader + "main:\n mov rax, 0\n ret\n"
if strings.Trim(actual, " \n\t") != strings.Trim(expected, " \n\t") { if strings.Trim(actual, " \n\t") != strings.Trim(expected, " \n\t") {
t.Errorf("Expected program to be:\n>>%s<<\nbut got:\n>>%s<<\n", expected, actual) t.Errorf("Expected program to be:\n>>%s<<\nbut got:\n>>%s<<\n", expected, actual)
} }
@ -130,6 +133,26 @@ func expectOperand(t *testing.T, expected Operand, actual Operand) {
if expected != actual { if expected != actual {
t.Errorf("Expected Immediate %q but got %q", expected, actual) t.Errorf("Expected Immediate %q but got %q", expected, actual)
} }
case Stack:
actual, ok := actual.(Stack)
if !ok {
t.Errorf("Expected Stack but got %T", actual)
}
if expected != actual {
t.Errorf("Expected Stack value %q but got %q", expected, actual)
}
case Pseudo:
actual, ok := actual.(Pseudo)
if !ok {
t.Errorf("Expected Stack but got %T", actual)
}
if expected != actual {
t.Errorf("Expected Stack value %q but got %q", expected, actual)
}
default: default:
t.Errorf("Unknown operand type %T", expected) t.Errorf("Unknown operand type %T", expected)
} }

17
asm/amd64/basic_test.txt Normal file
View File

@ -0,0 +1,17 @@
format ELF64 executable
segment readable executable
entry _start
_start:
call main
mov rdi, rax
mov rax, 60
syscall
main:
push rbp
mov rbp, rsp
add rsp, 0
mov rax, 0
mov rsp, rbp
pop rbp
ret

View File

@ -52,12 +52,12 @@ func cgInstruction(i ttir.Instruction) []Instruction {
switch i := i.(type) { switch i := i.(type) {
case *ttir.Ret: case *ttir.Ret:
return []Instruction{ return []Instruction{
{ &SimpleInstruction{
Opcode: Mov, Opcode: Mov,
Lhs: AX, Lhs: AX,
Rhs: toAsmOperand(i.Op), Rhs: toAsmOperand(i.Op),
}, },
{ &SimpleInstruction{
Opcode: Ret, Opcode: Ret,
}, },
} }
@ -70,6 +70,32 @@ func cgInstruction(i ttir.Instruction) []Instruction {
func cgBinary(b *ttir.Binary) []Instruction { func cgBinary(b *ttir.Binary) []Instruction {
switch b.Operator { switch b.Operator {
case ast.Equal, ast.NotEqual:
var condCode CondCode
switch b.Operator {
case ast.Equal:
condCode = Equal
case ast.NotEqual:
condCode = NotEqual
}
return []Instruction{
&SimpleInstruction{
Opcode: Cmp,
Lhs: toAsmOperand(b.Lhs),
Rhs: toAsmOperand(b.Rhs),
},
&SimpleInstruction{
Opcode: Mov,
Lhs: toAsmOperand(b.Dst),
Rhs: Imm(0),
},
&SetCCInstruction{
Cond: condCode,
Dst: toAsmOperand(b.Dst),
},
}
case ast.Add, ast.Subtract, ast.Multiply: case ast.Add, ast.Subtract, ast.Multiply:
var opcode Opcode var opcode Opcode
switch b.Operator { switch b.Operator {
@ -82,15 +108,15 @@ func cgBinary(b *ttir.Binary) []Instruction {
} }
return []Instruction{ return []Instruction{
{Opcode: Mov, Lhs: toAsmOperand(b.Dst), Rhs: toAsmOperand(b.Lhs)}, &SimpleInstruction{Opcode: Mov, Lhs: toAsmOperand(b.Dst), Rhs: toAsmOperand(b.Lhs)},
{Opcode: opcode, Lhs: toAsmOperand(b.Dst), Rhs: toAsmOperand(b.Rhs)}, &SimpleInstruction{Opcode: opcode, Lhs: toAsmOperand(b.Dst), Rhs: toAsmOperand(b.Rhs)},
} }
case ast.Divide: case ast.Divide:
return []Instruction{ return []Instruction{
{Opcode: Mov, Lhs: Register(AX), Rhs: toAsmOperand(b.Lhs)}, &SimpleInstruction{Opcode: Mov, Lhs: Register(AX), Rhs: toAsmOperand(b.Lhs)},
{Opcode: Cdq}, &SimpleInstruction{Opcode: Cdq},
{Opcode: Idiv, Lhs: toAsmOperand(b.Rhs)}, &SimpleInstruction{Opcode: Idiv, Lhs: toAsmOperand(b.Rhs)},
{Opcode: Mov, Lhs: toAsmOperand(b.Dst), Rhs: Register(AX)}, &SimpleInstruction{Opcode: Mov, Lhs: toAsmOperand(b.Dst), Rhs: Register(AX)},
} }
} }
@ -128,15 +154,26 @@ func rpFunction(f Function) Function {
func rpInstruction(i Instruction, r *replacePseudoPass) Instruction { func rpInstruction(i Instruction, r *replacePseudoPass) Instruction {
newInstruction := Instruction{Opcode: i.Opcode} switch i := i.(type) {
if i.Lhs != nil { case *SimpleInstruction:
newInstruction.Lhs = pseudoToStack(i.Lhs, r)
} newInstruction := &SimpleInstruction{Opcode: i.Opcode}
if i.Rhs != nil { if i.Lhs != nil {
newInstruction.Rhs = pseudoToStack(i.Rhs, r) newInstruction.Lhs = pseudoToStack(i.Lhs, r)
}
if i.Rhs != nil {
newInstruction.Rhs = pseudoToStack(i.Rhs, r)
}
return newInstruction
case *SetCCInstruction:
return &SetCCInstruction{
Cond: i.Cond,
Dst: pseudoToStack(i.Dst, r),
}
} }
return newInstruction panic("invalid instruction")
} }
func pseudoToStack(op Operand, r *replacePseudoPass) Operand { func pseudoToStack(op Operand, r *replacePseudoPass) Operand {
@ -177,40 +214,70 @@ func fixupFunction(f Function) Function {
func fixupInstruction(i Instruction) []Instruction { func fixupInstruction(i Instruction) []Instruction {
switch i.Opcode { switch i := i.(type) {
case Mov: case *SimpleInstruction:
if lhs, ok := i.Lhs.(Stack); ok { switch i.Opcode {
if rhs, ok := i.Rhs.(Stack); ok { case Mov:
if lhs, ok := i.Lhs.(Stack); ok {
if rhs, ok := i.Rhs.(Stack); ok {
return []Instruction{
&SimpleInstruction{Opcode: Mov, Lhs: Register(R10), Rhs: rhs},
&SimpleInstruction{Opcode: Mov, Lhs: lhs, Rhs: Register(R10)},
}
}
}
case Imul:
if lhs, ok := i.Lhs.(Stack); ok {
return []Instruction{ return []Instruction{
{Opcode: Mov, Lhs: Register(R10), Rhs: rhs}, &SimpleInstruction{Opcode: Mov, Lhs: Register(R11), Rhs: lhs},
{Opcode: Mov, Lhs: lhs, Rhs: Register(R10)}, &SimpleInstruction{Opcode: Imul, Lhs: Register(R11), Rhs: i.Rhs},
&SimpleInstruction{Opcode: Mov, Lhs: lhs, Rhs: Register(R11)},
}
}
fallthrough
case Add, Sub, Idiv /* Imul (fallthrough) */ :
if lhs, ok := i.Lhs.(Stack); ok {
if rhs, ok := i.Rhs.(Stack); ok {
return []Instruction{
&SimpleInstruction{Opcode: Mov, Lhs: Register(R10), Rhs: rhs},
&SimpleInstruction{Opcode: i.Opcode, Lhs: lhs, Rhs: Register(R10)},
}
}
} else if lhs, ok := i.Lhs.(Imm); ok && i.Opcode == Idiv {
return []Instruction{
&SimpleInstruction{Opcode: Mov, Lhs: Register(R10), Rhs: lhs},
&SimpleInstruction{Opcode: Idiv, Lhs: Register(R10)},
}
}
case Cmp:
if lhs, ok := i.Lhs.(Stack); ok {
if rhs, ok := i.Rhs.(Stack); ok {
return []Instruction{
&SimpleInstruction{Opcode: Mov, Lhs: Register(R10), Rhs: rhs},
&SimpleInstruction{Opcode: i.Opcode, Lhs: lhs, Rhs: Register(R10)},
}
}
} else if rhs, ok := i.Rhs.(Imm); ok {
return []Instruction{
&SimpleInstruction{
Opcode: Mov,
Lhs: Register(R11),
Rhs: Imm(rhs),
},
&SimpleInstruction{
Opcode: Cmp,
Lhs: i.Lhs,
Rhs: Register(R11),
},
} }
} }
} }
case Imul:
if lhs, ok := i.Lhs.(Stack); ok { return []Instruction{i}
return []Instruction{ case *SetCCInstruction:
{Opcode: Mov, Lhs: Register(R11), Rhs: lhs},
{Opcode: Imul, Lhs: Register(R11), Rhs: i.Rhs}, return []Instruction{i}
{Opcode: Mov, Lhs: lhs, Rhs: Register(R11)},
}
}
fallthrough
case Add, Sub, Idiv /* Imul (fallthrough) */ :
if lhs, ok := i.Lhs.(Stack); ok {
if rhs, ok := i.Rhs.(Stack); ok {
return []Instruction{
{Opcode: Mov, Lhs: Register(R10), Rhs: rhs},
{Opcode: i.Opcode, Lhs: lhs, Rhs: Register(R10)},
}
}
} else if lhs, ok := i.Lhs.(Imm); ok && i.Opcode == Idiv {
return []Instruction{
{Opcode: Mov, Lhs: Register(R10), Rhs: lhs},
{Opcode: Idiv, Lhs: Register(R10)},
}
}
} }
return []Instruction{i} panic("invalid instruction")
} }

View File

@ -82,6 +82,8 @@ const (
Subtract Subtract
Multiply Multiply
Divide Divide
Equal
NotEqual
) )
func (bo BinaryOperator) SymbolString() string { func (bo BinaryOperator) SymbolString() string {
@ -94,6 +96,10 @@ func (bo BinaryOperator) SymbolString() string {
return "*" return "*"
case Divide: case Divide:
return "/" return "/"
case Equal:
return "=="
case NotEqual:
return "!="
} }
return "<INVALID BINARY OPERATOR>" return "<INVALID BINARY OPERATOR>"
} }

12
language.md Normal file
View File

@ -0,0 +1,12 @@
# tt
## Syntax
```tt
// Return type is i64
fn main() = {
let i = 34;
i
};
```

View File

@ -66,6 +66,14 @@ func (l *Lexer) NextToken() token.Token {
case ';': case ';':
tok = l.newToken(token.Semicolon) tok = l.newToken(token.Semicolon)
case '=': case '=':
if l.peekByte() == '=' {
pos := l.position
l.readChar()
l.readChar()
tok.Type = token.DoubleEqual
tok.Literal = l.input[pos:l.position]
return tok
}
tok = l.newToken(token.Equal) tok = l.newToken(token.Equal)
case '(': case '(':
tok = l.newToken(token.OpenParen) tok = l.newToken(token.OpenParen)
@ -79,6 +87,16 @@ func (l *Lexer) NextToken() token.Token {
tok = l.newToken(token.Asterisk) tok = l.newToken(token.Asterisk)
case '/': case '/':
tok = l.newToken(token.Slash) tok = l.newToken(token.Slash)
case '!':
if l.peekByte() == '=' {
pos := l.position
l.readChar()
l.readChar()
tok.Type = token.NotEqual
tok.Literal = l.input[pos:l.position]
return tok
}
tok = l.newToken(token.Illegal)
case -1: case -1:
tok.Literal = "" tok.Literal = ""
tok.Type = token.Eof tok.Type = token.Eof
@ -140,6 +158,14 @@ func (l *Lexer) readChar() (err error) {
return return
} }
func (l *Lexer) peekByte() byte {
if l.readPosition < len(l.input) {
return l.input[l.readPosition]
} else {
return 0
}
}
func (l *Lexer) readIdentifier() string { func (l *Lexer) readIdentifier() string {
startPos := l.position startPos := l.position

View File

@ -40,7 +40,7 @@ func runLexerTest(t *testing.T, test lexerTest) {
func TestBasicFunctionality(t *testing.T) { func TestBasicFunctionality(t *testing.T) {
runLexerTest(t, lexerTest{ runLexerTest(t, lexerTest{
input: "fn main() = 0;", input: "fn main() = 0 + 3;",
expectedToken: []token.Token{ expectedToken: []token.Token{
{Type: token.Fn, Literal: "fn"}, {Type: token.Fn, Literal: "fn"},
{Type: token.Ident, Literal: "main"}, {Type: token.Ident, Literal: "main"},
@ -48,6 +48,8 @@ func TestBasicFunctionality(t *testing.T) {
{Type: token.CloseParen, Literal: ")"}, {Type: token.CloseParen, Literal: ")"},
{Type: token.Equal, Literal: "="}, {Type: token.Equal, Literal: "="},
{Type: token.Int, Literal: "0"}, {Type: token.Int, Literal: "0"},
{Type: token.Plus, Literal: "+"},
{Type: token.Int, Literal: "3"},
{Type: token.Semicolon, Literal: ";"}, {Type: token.Semicolon, Literal: ";"},
{Type: token.Eof, Literal: ""}, {Type: token.Eof, Literal: ""},
}, },

View File

@ -51,6 +51,8 @@ func New(l *lexer.Lexer) *Parser {
p.registerInfixFn(token.Minus, p.parseBinaryExpression) p.registerInfixFn(token.Minus, p.parseBinaryExpression)
p.registerInfixFn(token.Asterisk, p.parseBinaryExpression) p.registerInfixFn(token.Asterisk, p.parseBinaryExpression)
p.registerInfixFn(token.Slash, p.parseBinaryExpression) p.registerInfixFn(token.Slash, p.parseBinaryExpression)
p.registerInfixFn(token.Equal, p.parseBinaryExpression)
p.registerInfixFn(token.NotEqual, p.parseBinaryExpression)
p.nextToken() p.nextToken()
p.nextToken() p.nextToken()
@ -239,6 +241,10 @@ func (p *Parser) parseBinaryExpression(lhs ast.Expression) ast.Expression {
op = ast.Multiply op = ast.Multiply
case token.Slash: case token.Slash:
op = ast.Divide op = ast.Divide
case token.DoubleEqual:
op = ast.Equal
case token.NotEqual:
op = ast.NotEqual
default: default:
return p.exprError(p.curToken, "invalid token for binary expression %s", p.curToken.Type) return p.exprError(p.curToken, "invalid token for binary expression %s", p.curToken.Type)
} }

View File

@ -30,10 +30,14 @@ const (
Equal TokenType = "=" Equal TokenType = "="
OpenParen TokenType = "(" OpenParen TokenType = "("
CloseParen TokenType = ")" CloseParen TokenType = ")"
Plus TokenType = "+"
Minus TokenType = "-" // Binary Operators
Asterisk TokenType = "*" Plus TokenType = "+"
Slash TokenType = "/" Minus TokenType = "-"
Asterisk TokenType = "*"
Slash TokenType = "/"
DoubleEqual TokenType = "=="
NotEqual TokenType = "!="
// Keywords // Keywords
Fn TokenType = "FN" Fn TokenType = "FN"

View File

@ -3,6 +3,7 @@ package ttir
import ( import (
"fmt" "fmt"
"robaertschi.xyz/robaertschi/tt/ast"
"robaertschi.xyz/robaertschi/tt/tast" "robaertschi.xyz/robaertschi/tt/tast"
) )
@ -41,13 +42,15 @@ func emitExpression(expr tast.Expression) (Operand, []Instruction) {
case *tast.IntegerExpression: case *tast.IntegerExpression:
return &Constant{Value: expr.Value}, []Instruction{} return &Constant{Value: expr.Value}, []Instruction{}
case *tast.BinaryExpression: case *tast.BinaryExpression:
lhsDst, instructions := emitExpression(expr.Lhs) switch expr.Operator {
rhsDst, rhsInstructions := emitExpression(expr.Rhs) default:
instructions = append(instructions, rhsInstructions...) lhsDst, instructions := emitExpression(expr.Lhs)
dst := &Var{Value: temp()} rhsDst, rhsInstructions := emitExpression(expr.Rhs)
instructions = append(instructions, &Binary{Operator: expr.Operator, Lhs: lhsDst, Rhs: rhsDst, Dst: dst}) instructions = append(instructions, rhsInstructions...)
return dst, instructions dst := &Var{Value: temp()}
instructions = append(instructions, &Binary{Operator: expr.Operator, Lhs: lhsDst, Rhs: rhsDst, Dst: dst})
return dst, instructions
}
} }
panic("unhandled tast.Expression case in ir emitter") panic("unhandled tast.Expression case in ir emitter")
} }

View File

@ -60,6 +60,40 @@ func (b *Binary) String() string {
} }
func (b *Binary) instruction() {} func (b *Binary) instruction() {}
type JumpIfZero struct {
Value Operand
Label string
}
func (jiz *JumpIfZero) String() string {
return fmt.Sprintf("jz %v, %v\n", jiz.Value, jiz.Label)
}
func (jiz *JumpIfZero) instruction() {}
type JumpIfNotZero struct {
Value Operand
Label string
}
func (jiz *JumpIfNotZero) String() string {
return fmt.Sprintf("jnz %v, %v\n", jiz.Value, jiz.Label)
}
func (jiz *JumpIfNotZero) instruction() {}
type Jump string
func (j Jump) String() string {
return fmt.Sprintf("jmp %v\n", string(j))
}
func (j Jump) instruction() {}
type Label string
func (l Label) String() string {
return fmt.Sprintf("%v:\n", string(l))
}
func (l Label) instruction() {}
type Operand interface { type Operand interface {
String() string String() string
operand() operand()