diff --git a/crates/mun_codegen/src/ir/body.rs b/crates/mun_codegen/src/ir/body.rs index 20bd7c1fe..ce95d50ac 100644 --- a/crates/mun_codegen/src/ir/body.rs +++ b/crates/mun_codegen/src/ir/body.rs @@ -35,6 +35,7 @@ pub(crate) struct BodyIrGenerator<'a, 'b, D: IrDatabase> { function_map: &'a HashMap, dispatch_table: &'b DispatchTable, active_loop: Option, + hir_function: hir::Function, } impl<'a, 'b, D: IrDatabase> BodyIrGenerator<'a, 'b, D> { @@ -69,6 +70,7 @@ impl<'a, 'b, D: IrDatabase> BodyIrGenerator<'a, 'b, D> { function_map, dispatch_table, active_loop: None, + hir_function, } } @@ -107,11 +109,20 @@ impl<'a, 'b, D: IrDatabase> BodyIrGenerator<'a, 'b, D> { // Construct a return statement from the returned value of the body if a return is expected // in the first place. If the return type of the body is `never` there is no need to // generate a return statement. - let ret_type = &self.infer[self.body.body_expr()]; - if let Some(value) = ret_value { - self.builder.build_return(Some(&value)); - } else if !ret_type.is_never() { - self.builder.build_return(None); + let block_ret_type = &self.infer[self.body.body_expr()]; + let fn_ret_type = self + .hir_function + .ty(self.db) + .callable_sig(self.db) + .unwrap() + .ret() + .clone(); + if !block_ret_type.is_never() { + if fn_ret_type.is_empty() { + self.builder.build_return(None); + } else if let Some(value) = ret_value { + self.builder.build_return(Some(&value)); + } } } @@ -143,6 +154,7 @@ impl<'a, 'b, D: IrDatabase> BodyIrGenerator<'a, 'b, D> { } => self.gen_if(expr, *condition, *then_branch, *else_branch), Expr::Return { expr: ret_expr } => self.gen_return(expr, *ret_expr), Expr::Loop { body } => self.gen_loop(expr, *body), + Expr::While { condition, body } => self.gen_while(expr, *condition, *body), Expr::Break { expr: break_expr } => self.gen_break(expr, *break_expr), _ => unimplemented!("unimplemented expr type {:?}", &body[expr]), } @@ -178,6 +190,11 @@ impl<'a, 'b, D: IrDatabase> BodyIrGenerator<'a, 'b, D> { } } + /// Constructs an empty struct value e.g. `{}` + fn gen_empty(&mut self) -> BasicValueEnum { + self.module.get_context().const_struct(&[], false).into() + } + /// Generates IR for the specified block expression. fn gen_block( &mut self, @@ -193,16 +210,13 @@ impl<'a, 'b, D: IrDatabase> BodyIrGenerator<'a, 'b, D> { self.gen_let_statement(*pat, *initializer); } Statement::Expr(expr) => { - self.gen_expr(*expr); - // No need to generate code after a statement that has a `never` return type. - if self.infer[*expr].is_never() { - return None; - } + self.gen_expr(*expr)?; } }; } tail.and_then(|expr| self.gen_expr(expr)) + .or_else(|| Some(self.gen_empty())) } /// Constructs a builder that should be used to emit an `alloca` instruction. These instructions @@ -365,7 +379,7 @@ impl<'a, 'b, D: IrDatabase> BodyIrGenerator<'a, 'b, D> { }; let place = self.gen_place_expr(lhs_expr); self.builder.build_store(place, rhs); - None + Some(self.gen_empty()) } _ => unimplemented!("Operator {:?} is not implemented for float", op), } @@ -422,7 +436,7 @@ impl<'a, 'b, D: IrDatabase> BodyIrGenerator<'a, 'b, D> { }; let place = self.gen_place_expr(lhs_expr); self.builder.build_store(place, rhs); - None + Some(self.gen_empty()) } _ => unreachable!(format!("Operator {:?} is not implemented for integer", op)), } @@ -507,10 +521,7 @@ impl<'a, 'b, D: IrDatabase> BodyIrGenerator<'a, 'b, D> { else_branch: Option, ) -> Option { // Generate IR for the condition - let condition_ir = self - .gen_expr(condition) - .expect("condition must have a value") - .into_int_value(); + let condition_ir = self.gen_expr(condition)?.into_int_value(); // Generate the code blocks to branch to let context = self.module.get_context(); @@ -570,7 +581,7 @@ impl<'a, 'b, D: IrDatabase> BodyIrGenerator<'a, 'b, D> { Some(then_block_ir) } } else { - None + Some(self.gen_empty()) } } @@ -600,34 +611,92 @@ impl<'a, 'b, D: IrDatabase> BodyIrGenerator<'a, 'b, D> { None } - fn gen_loop(&mut self, _expr: ExprId, body_expr: ExprId) -> Option { - let context = self.module.get_context(); - let loop_block = context.append_basic_block(&self.fn_value, "loop"); - let exit_block = context.append_basic_block(&self.fn_value, "exit"); - + fn gen_loop_block_expr( + &mut self, + block: ExprId, + exit_block: BasicBlock, + ) -> ( + BasicBlock, + Vec<(BasicValueEnum, BasicBlock)>, + Option, + ) { // Build a new loop info struct let loop_info = LoopInfo { exit_block, break_values: Vec::new(), }; + // Replace previous loop info let prev_loop = std::mem::replace(&mut self.active_loop, Some(loop_info)); - // Insert an explicit fall through from the current block to the loop - self.builder.build_unconditional_branch(&loop_block); - // Start generating code inside the loop - self.builder.position_at_end(&loop_block); - let _ = self.gen_expr(body_expr); - - // Jump to the start of the loop - self.builder.build_unconditional_branch(&loop_block); + let value = self.gen_expr(block); let LoopInfo { exit_block, break_values, } = std::mem::replace(&mut self.active_loop, prev_loop).unwrap(); + (exit_block, break_values, value) + } + + fn gen_while( + &mut self, + _expr: ExprId, + condition_expr: ExprId, + body_expr: ExprId, + ) -> Option { + let context = self.module.get_context(); + let cond_block = context.append_basic_block(&self.fn_value, "whilecond"); + let loop_block = context.append_basic_block(&self.fn_value, "while"); + let exit_block = context.append_basic_block(&self.fn_value, "afterwhile"); + + // Insert an explicit fall through from the current block to the condition check + self.builder.build_unconditional_branch(&cond_block); + + // Generate condition block + self.builder.position_at_end(&cond_block); + let condition_ir = self.gen_expr(condition_expr); + if let Some(condition_ir) = condition_ir { + self.builder.build_conditional_branch( + condition_ir.into_int_value(), + &loop_block, + &exit_block, + ); + } else { + // If the condition doesn't return a value, we also immediately return without a value. + // This can happen if the expression is a `never` expression. + return None; + } + + // Generate loop block + self.builder.position_at_end(&loop_block); + let (exit_block, _, value) = self.gen_loop_block_expr(body_expr, exit_block); + if value.is_some() { + self.builder.build_unconditional_branch(&cond_block); + } + + // Generate exit block + self.builder.position_at_end(&exit_block); + + Some(self.gen_empty()) + } + + fn gen_loop(&mut self, _expr: ExprId, body_expr: ExprId) -> Option { + let context = self.module.get_context(); + let loop_block = context.append_basic_block(&self.fn_value, "loop"); + let exit_block = context.append_basic_block(&self.fn_value, "exit"); + + // Insert an explicit fall through from the current block to the loop + self.builder.build_unconditional_branch(&loop_block); + + // Generate the body of the loop + self.builder.position_at_end(&loop_block); + let (exit_block, break_values, value) = self.gen_loop_block_expr(body_expr, exit_block); + if value.is_some() { + self.builder.build_unconditional_branch(&loop_block); + } + // Move the builder to the exit block self.builder.position_at_end(&exit_block); diff --git a/crates/mun_codegen/src/snapshots/test__while_expr.snap b/crates/mun_codegen/src/snapshots/test__while_expr.snap new file mode 100644 index 000000000..e56439f69 --- /dev/null +++ b/crates/mun_codegen/src/snapshots/test__while_expr.snap @@ -0,0 +1,24 @@ +--- +source: crates/mun_codegen/src/test.rs +expression: "fn foo(n:int) {\n while n<3 {\n n += 1;\n };\n\n // This will be completely optimized out\n while n<4 {\n break;\n };\n}" +--- +; ModuleID = 'main.mun' +source_filename = "main.mun" + +define void @foo(i64) { +body: + br label %whilecond + +whilecond: ; preds = %while, %body + %n.0 = phi i64 [ %0, %body ], [ %add, %while ] + %less = icmp slt i64 %n.0, 3 + br i1 %less, label %while, label %whilecond3 + +while: ; preds = %whilecond + %add = add i64 %n.0, 1 + br label %whilecond + +whilecond3: ; preds = %whilecond + ret void +} + diff --git a/crates/mun_codegen/src/test.rs b/crates/mun_codegen/src/test.rs index eaaed1e2c..9e1f2c61e 100644 --- a/crates/mun_codegen/src/test.rs +++ b/crates/mun_codegen/src/test.rs @@ -346,6 +346,24 @@ fn loop_break_expr() { ) } +#[test] +fn while_expr() { + test_snapshot( + r#" + fn foo(n:int) { + while n<3 { + n += 1; + }; + + // This will be completely optimized out + while n<4 { + break; + }; + } + "#, + ) +} + fn test_snapshot(text: &str) { let text = text.trim().replace("\n ", "\n"); diff --git a/crates/mun_hir/src/ty/infer.rs b/crates/mun_hir/src/ty/infer.rs index 4ef3dea94..db74dcf52 100644 --- a/crates/mun_hir/src/ty/infer.rs +++ b/crates/mun_hir/src/ty/infer.rs @@ -23,15 +23,16 @@ mod type_variable; pub use type_variable::TypeVarId; +#[macro_export] macro_rules! ty_app { ($ctor:pat, $param:pat) => { - $crate::ty::Ty::Apply($crate::ty::ApplicationTy { + $crate::Ty::Apply($crate::ApplicationTy { ctor: $ctor, parameters: $param, }) }; ($ctor:pat) => { - $crate::ty::Ty::Apply($crate::ty::ApplicationTy { + $crate::Ty::Apply($crate::ApplicationTy { ctor: $ctor, .. }) diff --git a/crates/mun_hir/src/ty/snapshots/tests__infer_while.snap b/crates/mun_hir/src/ty/snapshots/tests__infer_while.snap index dac9462eb..094af5cf7 100644 --- a/crates/mun_hir/src/ty/snapshots/tests__infer_while.snap +++ b/crates/mun_hir/src/ty/snapshots/tests__infer_while.snap @@ -1,9 +1,9 @@ --- source: crates/mun_hir/src/ty/tests.rs -expression: "fn foo() {\n let n = 0;\n while n < 3 { n += 1; };\n while n < 3 { n += 1; break; };\n while n < 3 { break 3; }; // Error: break with value in while\n while n < 3 { loop { break 3; }; };\n}" +expression: "fn foo() {\n let n = 0;\n while n < 3 { n += 1; };\n while n < 3 { n += 1; break; };\n while n < 3 { break 3; }; // error: break with value can only appear in a loop\n while n < 3 { loop { break 3; }; };\n}" --- [109; 116): `break` with value can only appear in a `loop` -[9; 200) '{ ...; }; }': nothing +[9; 217) '{ ...; }; }': nothing [19; 20) 'n': int [23; 24) '0': int [30; 53) 'while ...= 1; }': nothing @@ -29,12 +29,12 @@ expression: "fn foo() {\n let n = 0;\n while n < 3 { n += 1; };\n while [105; 106) '3': int [107; 119) '{ break 3; }': never [109; 116) 'break 3': never -[163; 197) 'while ...; }; }': nothing -[169; 170) 'n': int -[169; 174) 'n < 3': bool -[173; 174) '3': int -[175; 197) '{ loop...; }; }': nothing -[177; 194) 'loop {...k 3; }': int -[182; 194) '{ break 3; }': never -[184; 191) 'break 3': never +[180; 214) 'while ...; }; }': nothing +[186; 187) 'n': int +[186; 191) 'n < 3': bool [190; 191) '3': int +[192; 214) '{ loop...; }; }': nothing +[194; 211) 'loop {...k 3; }': int +[199; 211) '{ break 3; }': never +[201; 208) 'break 3': never +[207; 208) '3': int diff --git a/crates/mun_hir/src/ty/tests.rs b/crates/mun_hir/src/ty/tests.rs index 6fecf69d2..439f22907 100644 --- a/crates/mun_hir/src/ty/tests.rs +++ b/crates/mun_hir/src/ty/tests.rs @@ -152,7 +152,7 @@ fn infer_while() { let n = 0; while n < 3 { n += 1; }; while n < 3 { n += 1; break; }; - while n < 3 { break 3; }; // Error: break with value in while + while n < 3 { break 3; }; // error: break with value can only appear in a loop while n < 3 { loop { break 3; }; }; } "#, diff --git a/crates/mun_runtime/src/test.rs b/crates/mun_runtime/src/test.rs index c2573cb76..203865b90 100644 --- a/crates/mun_runtime/src/test.rs +++ b/crates/mun_runtime/src/test.rs @@ -241,6 +241,31 @@ fn fibonacci_loop_break() { assert_invoke_eq!(i64, 46368, driver, "fibonacci", 24i64); } +#[test] +fn fibonacci_while() { + let mut driver = TestDriver::new( + r#" + fn fibonacci(n:int):int { + let a = 0; + let b = 1; + let i = 1; + while i <= n { + let sum = a + b; + a = b; + b = sum; + i += 1; + } + a + } + "#, + ); + + assert_invoke_eq!(i64, 5, driver, "fibonacci", 5i64); + assert_invoke_eq!(i64, 89, driver, "fibonacci", 11i64); + assert_invoke_eq!(i64, 987, driver, "fibonacci", 16i64); + assert_invoke_eq!(i64, 46368, driver, "fibonacci", 24i64); +} + #[test] fn true_is_true() { let mut driver = TestDriver::new(