diff --git a/crates/mun_codegen/src/ir/function.rs b/crates/mun_codegen/src/ir/function.rs index 440095c1c..acce5c0ce 100644 --- a/crates/mun_codegen/src/ir/function.rs +++ b/crates/mun_codegen/src/ir/function.rs @@ -1,12 +1,10 @@ use super::try_convert_any_to_basic; use crate::ir::dispatch_table::DispatchTable; -use crate::values::{ - BasicValueEnum, CallSiteValue, FloatValue, FunctionValue, InstructionOpcode, IntValue, -}; +use crate::values::{BasicValueEnum, CallSiteValue, FloatValue, FunctionValue, IntValue}; use crate::{IrDatabase, Module, OptimizationLevel}; use inkwell::builder::Builder; use inkwell::passes::{PassManager, PassManagerBuilder}; -use inkwell::types::{AnyTypeEnum, BasicTypeEnum}; +use inkwell::types::AnyTypeEnum; use inkwell::{FloatPredicate, IntPredicate}; use mun_hir::{ self as hir, ArithOp, BinaryOp, Body, CmpOp, Expr, ExprId, HirDisplay, InferenceResult, @@ -145,7 +143,7 @@ impl<'a, 'b, D: IrDatabase> BodyIrGenerator<'a, 'b, D> { fn gen_expr(&mut self, expr: ExprId) -> Option { let body = self.body.clone(); - let mut value = match &body[expr] { + let value = match &body[expr] { &Expr::Block { ref statements, tail, @@ -192,22 +190,27 @@ impl<'a, 'b, D: IrDatabase> BodyIrGenerator<'a, 'b, D> { ref callee, ref args, } => self.gen_call(*callee, &args).try_as_basic_value().left(), + Expr::If { + condition, + then_branch, + else_branch, + } => self.gen_if(expr, *condition, *then_branch, *else_branch), _ => unreachable!("unimplemented expr type"), }; - // Check expected type or perform implicit cast - value = value.map(|value| { - match ( - value.get_type(), - try_convert_any_to_basic(self.db.type_ir(self.infer[expr].clone())), - ) { - (BasicTypeEnum::IntType(_), Some(target @ BasicTypeEnum::FloatType(_))) => self - .builder - .build_cast(InstructionOpcode::SIToFP, value, target, "implicit_cast"), - (a, Some(b)) if a == b => value, - _ => unreachable!("could not perform implicit cast"), - } - }); + // // Check expected type or perform implicit cast + // value = value.map(|value| { + // match ( + // value.get_type(), + // try_convert_any_to_basic(self.db.type_ir(self.infer[expr].clone())), + // ) { + // (BasicTypeEnum::IntType(_), Some(target @ BasicTypeEnum::FloatType(_))) => self + // .builder + // .build_cast(InstructionOpcode::SIToFP, value, target, "implicit_cast"), + // (a, Some(b)) if a == b => value, + // _ => unreachable!("could not perform implicit cast"), + // } + // }); value } @@ -217,8 +220,8 @@ impl<'a, 'b, D: IrDatabase> BodyIrGenerator<'a, 'b, D> { fn new_alloca_builder(&self) -> Builder { let temp_builder = Builder::create(); let block = self - .builder - .get_insert_block() + .fn_value + .get_first_basic_block() .expect("at this stage there must be a block"); if let Some(first_instruction) = block.get_first_instruction() { temp_builder.position_before(&first_instruction); @@ -349,7 +352,7 @@ impl<'a, 'b, D: IrDatabase> BodyIrGenerator<'a, 'b, D> { .build_float_compare(predicate, lhs, rhs, name) .into() } - _ => unreachable!(), + _ => unreachable!(format!("Operator {:?} is not implemented for float", op)), } } @@ -365,8 +368,6 @@ impl<'a, 'b, D: IrDatabase> BodyIrGenerator<'a, 'b, D> { BinaryOp::ArithOp(ArithOp::Multiply) => { self.builder.build_int_mul(lhs, rhs, "mul").into() } - // BinaryOp::Remainder => Some(self.gen_remainder(lhs, rhs)), - // BinaryOp::Power =>, // BinaryOp::Assign, // BinaryOp::AddAssign, // BinaryOp::SubtractAssign, @@ -399,7 +400,7 @@ impl<'a, 'b, D: IrDatabase> BodyIrGenerator<'a, 'b, D> { .build_int_compare(predicate, lhs, rhs, name) .into() } - _ => unreachable!(), + _ => unreachable!(format!("Operator {:?} is not implemented for integer", op)), } } @@ -436,6 +437,75 @@ impl<'a, 'b, D: IrDatabase> BodyIrGenerator<'a, 'b, D> { .build_call(*llvm_function, &args, &function.name(self.db).to_string()) } } + + /// Generates IR for an if statement. + fn gen_if( + &mut self, + _expr: ExprId, + condition: ExprId, + then_branch: ExprId, + else_branch: Option, + ) -> Option { + // Generate IR for the condition + let condition_ir = self + .gen_expr(condition) + .expect("condition must have a value") + .into_int_value(); + + // Generate the code blocks to branch to + let context = self.module.get_context(); + let mut then_block = context.append_basic_block(&self.fn_value, "then"); + let else_block_and_expr = match &else_branch { + Some(else_branch) => Some(( + context.append_basic_block(&self.fn_value, "else"), + else_branch, + )), + None => None, + }; + let merge_block = context.append_basic_block(&self.fn_value, "if_merge"); + + // Build the actual branching IR for the if statement + let else_block = else_block_and_expr + .as_ref() + .map(|e| &e.0) + .unwrap_or(&merge_block); + self.builder + .build_conditional_branch(condition_ir, &then_block, else_block); + + // Fill the then block + self.builder.position_at_end(&then_block); + let then_block_ir = self.gen_expr(then_branch); + self.builder.build_unconditional_branch(&merge_block); + then_block = self.builder.get_insert_block().unwrap(); + + // Fill the else block, if it exists and get the result back + let else_ir_and_block = if let Some((else_block, else_branch)) = else_block_and_expr { + else_block.move_after(&then_block); + self.builder.position_at_end(&else_block); + let result_ir = self.gen_expr(*else_branch); + self.builder.build_unconditional_branch(&merge_block); + result_ir.map(|res| (res, self.builder.get_insert_block().unwrap())) + } else { + None + }; + + // Create merge block + merge_block.move_after(&self.builder.get_insert_block().unwrap()); + self.builder.position_at_end(&merge_block); + + // Construct phi block if a value was returned + if let Some(then_block_ir) = then_block_ir { + if let Some((else_block_ir, else_block)) = else_ir_and_block { + let phi = self.builder.build_phi(then_block_ir.get_type(), "iftmp"); + phi.add_incoming(&[(&then_block_ir, &then_block), (&else_block_ir, &else_block)]); + Some(phi.as_basic_value()) + } else { + Some(then_block_ir) + } + } else { + None + } + } } trait OptName { diff --git a/crates/mun_codegen/src/snapshots/test__fibonacci.snap b/crates/mun_codegen/src/snapshots/test__fibonacci.snap new file mode 100644 index 000000000..3b3d9de42 --- /dev/null +++ b/crates/mun_codegen/src/snapshots/test__fibonacci.snap @@ -0,0 +1,31 @@ +--- +source: crates/mun_codegen/src/test.rs +expression: "fn fibonacci(n:int):int {\n if n <= 1 {\n n\n } else {\n fibonacci(n-1) + fibonacci(n-2)\n }\n}" +--- +; ModuleID = 'main.mun' +source_filename = "main.mun" + +%DispatchTable = type { i64 (i64)* } + +@dispatchTable = global %DispatchTable { i64 (i64)* @fibonacci } + +define i64 @fibonacci(i64 %n) { +body: + %lesseq = icmp sle i64 %n, 1 + br i1 %lesseq, label %if_merge, label %else + +else: ; preds = %body + %sub = sub i64 %n, 1 + %fibonacci_ptr = load i64 (i64)*, i64 (i64)** getelementptr inbounds (%DispatchTable, %DispatchTable* @dispatchTable, i32 0, i32 0) + %fibonacci = call i64 %fibonacci_ptr(i64 %sub) + %sub1 = sub i64 %n, 2 + %fibonacci_ptr2 = load i64 (i64)*, i64 (i64)** getelementptr inbounds (%DispatchTable, %DispatchTable* @dispatchTable, i32 0, i32 0) + %fibonacci3 = call i64 %fibonacci_ptr2(i64 %sub1) + %add = add i64 %fibonacci, %fibonacci3 + br label %if_merge + +if_merge: ; preds = %body, %else + %iftmp = phi i64 [ %add, %else ], [ %n, %body ] + ret i64 %iftmp +} + diff --git a/crates/mun_codegen/src/snapshots/test__if_statement.snap b/crates/mun_codegen/src/snapshots/test__if_statement.snap new file mode 100644 index 000000000..a65cad628 --- /dev/null +++ b/crates/mun_codegen/src/snapshots/test__if_statement.snap @@ -0,0 +1,28 @@ +--- +source: crates/mun_codegen/src/test.rs +expression: "fn foo(a:int):int {\n let b = if a > 3 {\n let c = if a > 4 {\n a+1\n } else {\n a+3\n }\n c\n } else {\n a-1\n }\n b\n}" +--- +; ModuleID = 'main.mun' +source_filename = "main.mun" + +define i64 @foo(i64 %a) { +body: + %greater = icmp sgt i64 %a, 3 + br i1 %greater, label %then, label %else + +then: ; preds = %body + %greater1 = icmp sgt i64 %a, 4 + %add = add i64 %a, 1 + %add5 = add i64 %a, 3 + %iftmp = select i1 %greater1, i64 %add, i64 %add5 + br label %if_merge + +else: ; preds = %body + %sub = sub i64 %a, 1 + br label %if_merge + +if_merge: ; preds = %else, %then + %iftmp7 = phi i64 [ %iftmp, %then ], [ %sub, %else ] + ret i64 %iftmp7 +} + diff --git a/crates/mun_codegen/src/test.rs b/crates/mun_codegen/src/test.rs index f21e449e1..5ff5cb0ed 100644 --- a/crates/mun_codegen/src/test.rs +++ b/crates/mun_codegen/src/test.rs @@ -6,45 +6,6 @@ use mun_hir::SourceDatabase; use std::cell::RefCell; use std::sync::Arc; -fn test_snapshot(text: &str) { - let text = text.trim().replace("\n ", "\n"); - - let (db, file_id) = MockDatabase::with_single_file(&text); - - let line_index: Arc = db.line_index(file_id); - let messages = RefCell::new(Vec::new()); - let mut sink = DiagnosticSink::new(|diag| { - let line_col = line_index.line_col(diag.highlight_range().start()); - messages.borrow_mut().push(format!( - "error {}:{}: {}", - line_col.line + 1, - line_col.col + 1, - diag.message() - )); - }); - if let Some(module) = Module::package_modules(&db) - .iter() - .find(|m| m.file_id() == file_id) - { - module.diagnostics(&db, &mut sink) - } - drop(sink); - let messages = messages.into_inner(); - - let name = if !messages.is_empty() { - messages.join("\n") - } else { - format!( - "{}", - db.module_ir(file_id) - .llvm_module - .print_to_string() - .to_string() - ) - }; - insta::assert_snapshot!(insta::_macro_support::AutoName, name, &text); -} - #[test] fn function() { test_snapshot( @@ -185,3 +146,78 @@ fn equality_operands() { "#, ); } + +#[test] +fn if_statement() { + test_snapshot( + r#" + fn foo(a:int):int { + let b = if a > 3 { + let c = if a > 4 { + a+1 + } else { + a+3 + } + c + } else { + a-1 + } + b + } + "#, + ) +} + +#[test] +fn fibonacci() { + test_snapshot( + r#" + fn fibonacci(n:int):int { + if n <= 1 { + n + } else { + fibonacci(n-1) + fibonacci(n-2) + } + } + "#, + ) +} + +fn test_snapshot(text: &str) { + let text = text.trim().replace("\n ", "\n"); + + let (db, file_id) = MockDatabase::with_single_file(&text); + + let line_index: Arc = db.line_index(file_id); + let messages = RefCell::new(Vec::new()); + let mut sink = DiagnosticSink::new(|diag| { + let line_col = line_index.line_col(diag.highlight_range().start()); + messages.borrow_mut().push(format!( + "error {}:{}: {}", + line_col.line + 1, + line_col.col + 1, + diag.message() + )); + }); + if let Some(module) = Module::package_modules(&db) + .iter() + .find(|m| m.file_id() == file_id) + { + module.diagnostics(&db, &mut sink) + } + drop(sink); + let messages = messages.into_inner(); + + let name = if !messages.is_empty() { + messages.join("\n") + } else { + format!( + "{}", + db.module_ir(file_id) + .llvm_module + .print_to_string() + .to_string() + ) + }; + insta::assert_snapshot!(insta::_macro_support::AutoName, name, &text); +} diff --git a/crates/mun_runtime/src/test.rs b/crates/mun_runtime/src/test.rs index 0bfa98498..6388d8a9f 100644 --- a/crates/mun_runtime/src/test.rs +++ b/crates/mun_runtime/src/test.rs @@ -207,3 +207,31 @@ fn booleans() { false ); } + +#[test] +fn fibonacci() { + let compile_result = compile( + r#" + fn fibonacci(n:int):int { + if n <= 1 { + n + } else { + fibonacci(n-1) + fibonacci(n-2) + } + } + "#, + ); + let mut runtime = compile_result.new_runtime(); + assert_eq!( + MunRuntime::invoke_fn1::(&mut runtime, "fibonacci", 5).unwrap(), + 5 + ); + assert_eq!( + MunRuntime::invoke_fn1::(&mut runtime, "fibonacci", 11).unwrap(), + 89 + ); + assert_eq!( + MunRuntime::invoke_fn1::(&mut runtime, "fibonacci", 16).unwrap(), + 987 + ); +}