Skip to content

Commit

Permalink
fix: fixes issue #225
Browse files Browse the repository at this point in the history
  • Loading branch information
baszalmstra committed Jun 23, 2020
1 parent b738800 commit 8470812
Show file tree
Hide file tree
Showing 7 changed files with 211 additions and 50 deletions.
129 changes: 98 additions & 31 deletions crates/mun_codegen/src/ir/body.rs
Original file line number Diff line number Diff line change
Expand Up @@ -261,11 +261,7 @@ impl<'a, 'b, D: IrDatabase> BodyIrGenerator<'a, 'b, D> {
Expr::Field {
expr: receiver_expr,
name,
} => {
let ptr = self.gen_field(expr, *receiver_expr, name);
let value = self.builder.build_load(ptr, &name.to_string());
Some(value)
}
} => self.gen_field(expr, *receiver_expr, name),
_ => unimplemented!("unimplemented expr type {:?}", &body[expr]),
}
}
Expand Down Expand Up @@ -569,18 +565,20 @@ impl<'a, 'b, D: IrDatabase> BodyIrGenerator<'a, 'b, D> {
}
}

/// Given an expression and the type of the expression, optionally dereference the value.
fn opt_deref_value(&mut self, ty: hir::Ty, value: BasicValueEnum) -> BasicValueEnum {
match ty {
hir::Ty::Apply(hir::ApplicationTy {
ctor: hir::TypeCtor::Struct(s),
..
}) => match s.data(self.db).memory_kind {
hir::StructMemoryKind::GC => deref_heap_value(&self.builder, value),
hir::StructMemoryKind::Value => value,
},
_ => value,
/// Given an expression and its value optionally dereference the value to get to the actual
/// value. This is useful if we need to do an indirection to get to the actual value.
fn opt_deref_value(&mut self, expr: ExprId, value: BasicValueEnum) -> BasicValueEnum {
let ty = &self.infer[expr];
if let hir::Ty::Apply(hir::ApplicationTy {
ctor: hir::TypeCtor::Struct(s),
..
}) = ty
{
if s.data(self.db).memory_kind == hir::StructMemoryKind::GC {
return deref_heap_value(&self.builder, value);
}
}
value
}

/// Generates IR for looking up a certain path expression.
Expand Down Expand Up @@ -650,7 +648,7 @@ impl<'a, 'b, D: IrDatabase> BodyIrGenerator<'a, 'b, D> {
fn gen_unary_op_float(&mut self, expr: ExprId, op: UnaryOp) -> Option<BasicValueEnum> {
let value: FloatValue = self
.gen_expr(expr)
.map(|value| self.opt_deref_value(self.infer[expr].clone(), value))
.map(|value| self.opt_deref_value(expr, value))
.expect("no value")
.into_float_value();
match op {
Expand All @@ -668,7 +666,7 @@ impl<'a, 'b, D: IrDatabase> BodyIrGenerator<'a, 'b, D> {
) -> Option<BasicValueEnum> {
let value: IntValue = self
.gen_expr(expr)
.map(|value| self.opt_deref_value(self.infer[expr].clone(), value))
.map(|value| self.opt_deref_value(expr, value))
.expect("no value")
.into_int_value();
match op {
Expand All @@ -688,7 +686,7 @@ impl<'a, 'b, D: IrDatabase> BodyIrGenerator<'a, 'b, D> {
fn gen_unary_op_bool(&mut self, expr: ExprId, op: UnaryOp) -> Option<BasicValueEnum> {
let value: IntValue = self
.gen_expr(expr)
.map(|value| self.opt_deref_value(self.infer[expr].clone(), value))
.map(|value| self.opt_deref_value(expr, value))
.expect("no value")
.into_int_value();
match op {
Expand All @@ -706,12 +704,12 @@ impl<'a, 'b, D: IrDatabase> BodyIrGenerator<'a, 'b, D> {
) -> Option<BasicValueEnum> {
let lhs: IntValue = self
.gen_expr(lhs_expr)
.map(|value| self.opt_deref_value(self.infer[lhs_expr].clone(), value))
.map(|value| self.opt_deref_value(lhs_expr, value))
.expect("no lhs value")
.into_int_value();
let rhs: IntValue = self
.gen_expr(rhs_expr)
.map(|value| self.opt_deref_value(self.infer[rhs_expr].clone(), value))
.map(|value| self.opt_deref_value(rhs_expr, value))
.expect("no rhs value")
.into_int_value();
match op {
Expand Down Expand Up @@ -742,12 +740,12 @@ impl<'a, 'b, D: IrDatabase> BodyIrGenerator<'a, 'b, D> {
) -> Option<BasicValueEnum> {
let lhs = self
.gen_expr(lhs_expr)
.map(|value| self.opt_deref_value(self.infer[lhs_expr].clone(), value))
.map(|value| self.opt_deref_value(lhs_expr, value))
.expect("no lhs value")
.into_float_value();
let rhs = self
.gen_expr(rhs_expr)
.map(|value| self.opt_deref_value(self.infer[rhs_expr].clone(), value))
.map(|value| self.opt_deref_value(rhs_expr, value))
.expect("no rhs value")
.into_float_value();
match op {
Expand Down Expand Up @@ -802,12 +800,12 @@ impl<'a, 'b, D: IrDatabase> BodyIrGenerator<'a, 'b, D> {
) -> Option<BasicValueEnum> {
let lhs = self
.gen_expr(lhs_expr)
.map(|value| self.opt_deref_value(self.infer[lhs_expr].clone(), value))
.map(|value| self.opt_deref_value(lhs_expr, value))
.expect("no lhs value")
.into_int_value();
let rhs = self
.gen_expr(rhs_expr)
.map(|value| self.opt_deref_value(self.infer[lhs_expr].clone(), value))
.map(|value| self.opt_deref_value(rhs_expr, value))
.expect("no rhs value")
.into_int_value();
match op {
Expand Down Expand Up @@ -1024,11 +1022,22 @@ impl<'a, 'b, D: IrDatabase> BodyIrGenerator<'a, 'b, D> {
Expr::Field {
expr: receiver_expr,
name,
} => self.gen_field(expr, *receiver_expr, name),
} => self.gen_place_field(expr, *receiver_expr, name),
_ => unreachable!("invalid place expression"),
}
}

/// Returns true if the specified expression refers to an expression that results in a memory
/// address that can be used for other place operations.
fn is_place_expr(&self, expr: ExprId) -> bool {
let body = self.body.clone();
match &body[expr] {
Expr::Path(..) => true,
Expr::Field { expr, .. } => self.is_place_expr(*expr),
_ => false,
}
}

fn should_use_dispatch_table(&self) -> bool {
// FIXME: When we use the dispatch table, generated wrappers have infinite recursion
!self.params.make_marshallable
Expand Down Expand Up @@ -1068,7 +1077,7 @@ impl<'a, 'b, D: IrDatabase> BodyIrGenerator<'a, 'b, D> {
// Generate IR for the condition
let condition_ir = self
.gen_expr(condition)
.map(|value| self.opt_deref_value(self.infer[condition].clone(), value))?
.map(|value| self.opt_deref_value(condition, value))?
.into_int_value();

// Generate the code blocks to branch to
Expand Down Expand Up @@ -1208,7 +1217,7 @@ impl<'a, 'b, D: IrDatabase> BodyIrGenerator<'a, 'b, D> {
self.builder.position_at_end(&cond_block);
let condition_ir = self
.gen_expr(condition_expr)
.map(|value| self.opt_deref_value(self.infer[condition_expr].clone(), value));
.map(|value| self.opt_deref_value(condition_expr, value));
if let Some(condition_ir) = condition_ir {
self.builder.build_conditional_branch(
condition_ir.into_int_value(),
Expand Down Expand Up @@ -1264,11 +1273,69 @@ impl<'a, 'b, D: IrDatabase> BodyIrGenerator<'a, 'b, D> {
}
}

fn gen_field(&mut self, _expr: ExprId, receiver_expr: ExprId, name: &Name) -> PointerValue {
fn gen_field(
&mut self,
_expr: ExprId,
receiver_expr: ExprId,
name: &Name,
) -> Option<BasicValueEnum> {
let hir_struct = self.infer[receiver_expr]
.as_struct()
.expect("expected a struct");

let hir_struct_name = hir_struct.name(self.db);

let field_idx = hir_struct
.field(self.db, name)
.expect("expected a struct field")
.id()
.into_raw()
.into();

let field_ir_name = &format!("{}.{}", hir_struct_name, name);
if self.is_place_expr(receiver_expr) {
let receiver_ptr = self.gen_place_expr(receiver_expr);
let receiver_ptr = self
.opt_deref_value(receiver_expr, receiver_ptr.into())
.into_pointer_value();
unsafe {
let field_ptr = self.builder.build_struct_gep(
receiver_ptr,
field_idx,
&format!("{}.{}_ptr", hir_struct_name, name),
);
Some(self.builder.build_load(field_ptr, &field_ir_name))
}
} else {
let receiver_value = self.gen_expr(receiver_expr)?;
let receiver_value = self.opt_deref_value(receiver_expr, receiver_value);
let receiver_struct = receiver_value.into_struct_value();
Some(
self.builder
.build_extract_value(receiver_struct, field_idx, field_ir_name)
.ok_or_else(|| {
format!(
"could not extract field {} (index: {}) from struct {}",
name, field_idx, hir_struct_name
)
})
.unwrap(),
)
}
}

fn gen_place_field(
&mut self,
_expr: ExprId,
receiver_expr: ExprId,
name: &Name,
) -> PointerValue {
let hir_struct = self.infer[receiver_expr]
.as_struct()
.expect("expected a struct");

let hir_struct_name = hir_struct.name(self.db);

let field_idx = hir_struct
.field(self.db, name)
.expect("expected a struct field")
Expand All @@ -1278,13 +1345,13 @@ impl<'a, 'b, D: IrDatabase> BodyIrGenerator<'a, 'b, D> {

let receiver_ptr = self.gen_place_expr(receiver_expr);
let receiver_ptr = self
.opt_deref_value(self.infer[receiver_expr].clone(), receiver_ptr.into())
.opt_deref_value(receiver_expr, receiver_ptr.into())
.into_pointer_value();
unsafe {
self.builder.build_struct_gep(
receiver_ptr,
field_idx,
&format!("{}.{}", hir_struct.name(self.db), name),
&format!("{}.{}_ptr", hir_struct_name, name),
)
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,9 @@ body:
store %Foo** %Foo_ptr_ptr, %Foo*** %b
%mem_ptr = load %Foo**, %Foo*** %b
%deref = load %Foo*, %Foo** %mem_ptr
%Foo.a = getelementptr inbounds %Foo, %Foo* %deref, i32 0, i32 0
%a = load i32, i32* %Foo.a
ret i32 %a
%Foo.a_ptr = getelementptr inbounds %Foo, %Foo* %deref, i32 0, i32 0
%Foo.a = load i32, i32* %Foo.a_ptr
ret i32 %Foo.a
}


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@ source_filename = "main.mun"
define %Foo @bar_1(%Bar) {
body:
%.fca.1.0.extract = extractvalue %Bar %0, 1, 0
%"1.fca.0.insert" = insertvalue %Foo undef, i32 %.fca.1.0.extract, 0
ret %Foo %"1.fca.0.insert"
%Bar.1.fca.0.insert = insertvalue %Foo undef, i32 %.fca.1.0.extract, 0
ret %Foo %Bar.1.fca.0.insert
}

define i32 @foo_a(%Foo) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ source_filename = "main.mun"

define void @foo() {
body:
%b5 = alloca %Foo**
%b = alloca %Foo**
%a = alloca %Foo**
%new_ptr = load i8** (i8*, i8*)*, i8** (i8*, i8*)** getelementptr inbounds (%DispatchTable, %DispatchTable* @dispatchTable, i32 0, i32 0)
%Foo_ptr = load %struct.MunTypeInfo*, %struct.MunTypeInfo** getelementptr inbounds ([5 x %struct.MunTypeInfo*], [5 x %struct.MunTypeInfo*]* @global_type_table, i32 0, i32 0)
Expand All @@ -29,15 +29,15 @@ body:
store %Foo** %Foo_ptr_ptr, %Foo*** %a
%mem_ptr = load %Foo**, %Foo*** %a
%deref = load %Foo*, %Foo** %mem_ptr
%Foo.b = getelementptr inbounds %Foo, %Foo* %deref, i32 0, i32 1
%b = load i32, i32* %Foo.b
%add = add i32 %b, 3
%Foo.b_ptr = getelementptr inbounds %Foo, %Foo* %deref, i32 0, i32 1
%Foo.b = load i32, i32* %Foo.b_ptr
%add = add i32 %Foo.b, 3
%mem_ptr1 = load %Foo**, %Foo*** %a
%deref2 = load %Foo*, %Foo** %mem_ptr1
%Foo.b3 = getelementptr inbounds %Foo, %Foo* %deref2, i32 0, i32 1
store i32 %add, i32* %Foo.b3
%Foo.b_ptr3 = getelementptr inbounds %Foo, %Foo* %deref2, i32 0, i32 1
store i32 %add, i32* %Foo.b_ptr3
%a4 = load %Foo**, %Foo*** %a
store %Foo** %a4, %Foo*** %b5
store %Foo** %a4, %Foo*** %b
ret void
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,13 @@ body:
%b = alloca %Bar
%a = alloca %Foo
store %Foo { i32 5 }, %Foo* %a
%Foo.a = getelementptr inbounds %Foo, %Foo* %a, i32 0, i32 0
%a1 = load i32, i32* %Foo.a
%a2 = load %Foo, %Foo* %a
%init = insertvalue %Bar { double 1.230000e+00, i32 undef, i1 undef, %Foo undef }, i32 %a1, 1
%init3 = insertvalue %Bar %init, i1 true, 2
%init4 = insertvalue %Bar %init3, %Foo %a2, 3
store %Bar %init4, %Bar* %b
%Foo.a_ptr = getelementptr inbounds %Foo, %Foo* %a, i32 0, i32 0
%Foo.a = load i32, i32* %Foo.a_ptr
%a1 = load %Foo, %Foo* %a
%init = insertvalue %Bar { double 1.230000e+00, i32 undef, i1 undef, %Foo undef }, i32 %Foo.a, 1
%init2 = insertvalue %Bar %init, i1 true, 2
%init3 = insertvalue %Bar %init2, %Foo %a1, 3
store %Bar %init3, %Bar* %b
store %Baz undef, %Baz* %c
ret void
}
Expand Down
75 changes: 75 additions & 0 deletions crates/mun_codegen/src/snapshots/test__issue_225.snap
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
---
source: crates/mun_codegen/src/test.rs
expression: "struct Num {\n value: i64,\n}\n\npub fn foo(b: i64) {\n Num { value: b }.value;\n}\n\npub fn bar(b: i64) {\n { let a = Num { value: b }; a}.value;\n}"
---
; == FILE IR =====================================
; ModuleID = 'main.mun'
source_filename = "main.mun"

%DispatchTable = type { i8** (i8*, i8*)* }
%struct.MunTypeInfo = type { [16 x i8], i8*, i32, i8, i8 }
%Num = type { i64 }

@allocatorHandle = external global i8*
@dispatchTable = external global %DispatchTable
@global_type_table = external global [5 x %struct.MunTypeInfo*]

define void @foo(i64) {
body:
%init = insertvalue %Num undef, i64 %0, 0
%new_ptr = load i8** (i8*, i8*)*, i8** (i8*, i8*)** getelementptr inbounds (%DispatchTable, %DispatchTable* @dispatchTable, i32 0, i32 0)
%Num_ptr = load %struct.MunTypeInfo*, %struct.MunTypeInfo** getelementptr inbounds ([5 x %struct.MunTypeInfo*], [5 x %struct.MunTypeInfo*]* @global_type_table, i32 0, i32 2)
%type_info_ptr_to_i8_ptr = bitcast %struct.MunTypeInfo* %Num_ptr to i8*
%allocator_handle = load i8*, i8** @allocatorHandle
%new = call i8** %new_ptr(i8* %type_info_ptr_to_i8_ptr, i8* %allocator_handle)
%Num_ptr_ptr = bitcast i8** %new to %Num**
%Num_mem_ptr = load %Num*, %Num** %Num_ptr_ptr
store %Num %init, %Num* %Num_mem_ptr
%mem_ptr = load %Num*, %Num** %Num_ptr_ptr
%deref = load %Num, %Num* %mem_ptr
ret void
}

define void @bar(i64) {
body:
%init = insertvalue %Num undef, i64 %0, 0
%new_ptr = load i8** (i8*, i8*)*, i8** (i8*, i8*)** getelementptr inbounds (%DispatchTable, %DispatchTable* @dispatchTable, i32 0, i32 0)
%Num_ptr = load %struct.MunTypeInfo*, %struct.MunTypeInfo** getelementptr inbounds ([5 x %struct.MunTypeInfo*], [5 x %struct.MunTypeInfo*]* @global_type_table, i32 0, i32 2)
%type_info_ptr_to_i8_ptr = bitcast %struct.MunTypeInfo* %Num_ptr to i8*
%allocator_handle = load i8*, i8** @allocatorHandle
%new = call i8** %new_ptr(i8* %type_info_ptr_to_i8_ptr, i8* %allocator_handle)
%Num_ptr_ptr = bitcast i8** %new to %Num**
%Num_mem_ptr = load %Num*, %Num** %Num_ptr_ptr
store %Num %init, %Num* %Num_mem_ptr
%mem_ptr = load %Num*, %Num** %Num_ptr_ptr
%deref = load %Num, %Num* %mem_ptr
ret void
}


; == GROUP IR ====================================
; ModuleID = 'group_name'
source_filename = "group_name"

%DispatchTable = type { i8** (i8*, i8*)* }
%struct.MunTypeInfo = type { [16 x i8], i8*, i32, i8, i8 }
%struct.MunStructInfo = type { i8**, %struct.MunTypeInfo**, i16*, i16, i8 }

@dispatchTable = global %DispatchTable zeroinitializer
@"type_info::<*const TypeInfo>::name" = private unnamed_addr constant [16 x i8] c"*const TypeInfo\00"
@"type_info::<*const TypeInfo>" = private unnamed_addr constant %struct.MunTypeInfo { [16 x i8] c"=\A1-\1F\C2\A7\88`d\90\F4\B5\BEE}x", [16 x i8]* @"type_info::<*const TypeInfo>::name", i32 64, i8 8, i8 0 }
@"type_info::<core::i64>::name" = private unnamed_addr constant [10 x i8] c"core::i64\00"
@"type_info::<core::i64>" = private unnamed_addr constant %struct.MunTypeInfo { [16 x i8] c"G\13;t\97j8\18\D7M\83`\1D\C8\19%", [10 x i8]* @"type_info::<core::i64>::name", i32 64, i8 8, i8 0 }
@"type_info::<Num>::name" = private unnamed_addr constant [4 x i8] c"Num\00"
@"struct_info::<Num>::field_names.0" = private unnamed_addr constant [6 x i8] c"value\00"
@"struct_info::<Num>::field_names" = private unnamed_addr constant [1 x i8*] [i8* @"struct_info::<Num>::field_names.0"]
@"struct_info::<Num>::field_types" = private unnamed_addr constant [1 x %struct.MunTypeInfo*] [%struct.MunTypeInfo* @"type_info::<core::i64>"]
@"struct_info::<Num>::field_offsets" = private unnamed_addr constant [1 x i16] zeroinitializer
@"type_info::<Num>" = private unnamed_addr constant { %struct.MunTypeInfo, %struct.MunStructInfo } { %struct.MunTypeInfo { [16 x i8] c"\A92\E2p\B0\98\B2\C4\0C\A2\F5=x\904\00", [4 x i8]* @"type_info::<Num>::name", i32 64, i8 8, i8 1 }, %struct.MunStructInfo { [1 x i8*]* @"struct_info::<Num>::field_names", [1 x %struct.MunTypeInfo*]* @"struct_info::<Num>::field_types", [1 x i16]* @"struct_info::<Num>::field_offsets", i16 1, i8 0 } }
@"type_info::<*const *mut core::void>::name" = private unnamed_addr constant [23 x i8] c"*const *mut core::void\00"
@"type_info::<*const *mut core::void>" = private unnamed_addr constant %struct.MunTypeInfo { [16 x i8] c"\C5fO\BD\84\DF\06\BFd+\B1\9Abv\CE\00", [23 x i8]* @"type_info::<*const *mut core::void>::name", i32 64, i8 8, i8 0 }
@"type_info::<*mut core::void>::name" = private unnamed_addr constant [16 x i8] c"*mut core::void\00"
@"type_info::<*mut core::void>" = private unnamed_addr constant %struct.MunTypeInfo { [16 x i8] c"\F0Y\22\FC\95\9E\7F\CE\08T\B1\A2\CD\A7\FAz", [16 x i8]* @"type_info::<*mut core::void>::name", i32 64, i8 8, i8 0 }
@global_type_table = constant [5 x %struct.MunTypeInfo*] [%struct.MunTypeInfo* @"type_info::<*const TypeInfo>", %struct.MunTypeInfo* @"type_info::<core::i64>", %struct.MunTypeInfo* @"type_info::<Num>", %struct.MunTypeInfo* @"type_info::<*const *mut core::void>", %struct.MunTypeInfo* @"type_info::<*mut core::void>"]
@allocatorHandle = unnamed_addr global i8* null

Loading

0 comments on commit 8470812

Please sign in to comment.