Skip to content

Commit

Permalink
Start work on LICM
Browse files Browse the repository at this point in the history
  • Loading branch information
kevinjoseph1995 committed Jan 2, 2025
1 parent ab6a360 commit 9f90948
Show file tree
Hide file tree
Showing 6 changed files with 201 additions and 14 deletions.
5 changes: 3 additions & 2 deletions common/src/cfg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ impl NodeEntry for BasicBlock {
pub struct Dominators<'a> {
cfg: &'a Cfg,
// The dominator set for each node
// If an element is present in set set_per_node[i] then it is a dominator of node i
pub set_per_node: Vec<HashSet<NodeIndex>>,
}

Expand Down Expand Up @@ -109,7 +110,7 @@ impl<Data: NodeEntry> DirectedGraph<Data> {
let node_name = self.get_node_name(index);
let node_text = self.nodes[index].data.get_textual_representation();
statements.push(format!(
"\"{node_name}\" [shape=record, label=\"{node_name} | {node_text}\"]",
"\"{node_name}\" [shape=record, label=\"{node_name} \\| idx={index} | {node_text}\"]",
));
// Add more information for each node
}
Expand Down Expand Up @@ -356,7 +357,7 @@ impl<'a> Dominators<'a> {
}
}

fn convert_cfg_to_instruction_stream(cfg: Cfg) -> Vec<Code> {
pub fn convert_cfg_to_instruction_stream(cfg: Cfg) -> Vec<Code> {
cfg.dag
.nodes
.into_iter()
Expand Down
48 changes: 37 additions & 11 deletions dataflow_analysis/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,14 @@ enum Direction {
}

trait Analysis<'a, ValueType: Clone + Hash + Eq + Display> {
fn run(&self, cfg: &'a Cfg, init: HashSet<ValueType>, direction: Direction) -> () {
fn run(
&self,
cfg: &'a Cfg,
init: HashSet<ValueType>,
direction: Direction,
display: Option<bool>,
) -> (Vec<HashSet<ValueType>>, Vec<HashSet<ValueType>>) {
let display = display.unwrap_or(false);
let all_predecessors: Vec<&[usize]> = (0..cfg.dag.number_of_nodes())
.map(|node_index| cfg.dag.get_predecessor_indices(node_index))
.collect();
Expand Down Expand Up @@ -57,7 +64,10 @@ trait Analysis<'a, ValueType: Clone + Hash + Eq + Display> {
worklist.extend(output_edges[node_index]);
}
}
self.display(cfg, &input_list, &output_list);
if display {
self.display(cfg, &input_list, &output_list);
}
(input_list, output_list)
}

fn display(
Expand Down Expand Up @@ -127,11 +137,21 @@ trait Analysis<'a, ValueType: Clone + Hash + Eq + Display> {
) -> HashSet<ValueType>;
}

struct LiveVariableAnalysis {}
pub struct LiveVariableAnalysis {}

impl LiveVariableAnalysis {
pub fn run_analysis<'a>(
&self,
cfg: &'a Cfg,
display: Option<bool>,
) -> (Vec<HashSet<&'a str>>, Vec<HashSet<&'a str>>) {
self.run(cfg, HashSet::new(), Direction::Backward, display)
}
}

#[derive(Derivative)]
#[derivative(Eq, PartialEq, Hash)]
struct Definition<'a> {
pub struct Definition<'a> {
destination_variable: &'a str,
basic_block_index: usize,
instruction_index: usize,
Expand All @@ -152,7 +172,17 @@ impl Clone for Definition<'_> {
}
}

struct ReachingDefinitions {}
pub struct ReachingDefinitions {}

impl ReachingDefinitions {
pub fn run_analysis<'a>(
&self,
cfg: &'a Cfg,
display: Option<bool>,
) -> (Vec<HashSet<Definition<'a>>>, Vec<HashSet<Definition<'a>>>) {
self.run(cfg, HashSet::new(), Direction::Forward, display)
}
}

impl<'a> Analysis<'a, &'a str> for LiveVariableAnalysis {
fn merge(
Expand Down Expand Up @@ -313,14 +343,10 @@ pub fn run_analysis(dataflow_analysis_name: DataflowAnalyses, program: &Program)
.map(|f| (f, Cfg::new(f)))
.for_each(|(f, cfg)| match dataflow_analysis_name {
DataflowAnalyses::LiveVariable => {
LiveVariableAnalysis {}.run(&cfg, HashSet::new(), Direction::Backward);
let _ = LiveVariableAnalysis {}.run_analysis(&cfg, Some(true));
}
DataflowAnalyses::ReachingDefinitions => {
ReachingDefinitions {}.run(
&cfg,
create_set_of_definitions_from_function_arguments(&cfg, &f.args),
Direction::Forward,
);
let _ = ReachingDefinitions {}.run_analysis(&cfg, Some(true));
}
});
}
5 changes: 4 additions & 1 deletion driver/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,10 @@ struct Args {
#[arg(short, long, value_enum, help = "Type of dataflow analysis to run")]
dataflow_analysis: Option<DataflowAnalyses>,

#[arg(long, help = "Dump the AST as a DOT file")]
#[arg(
long,
help = "Dump the AST of each function in the program as DOT/Graphviz format"
)]
dump_ast_as_dot: bool,

#[arg(long, help = "Output the program after optimizations")]
Expand Down
1 change: 1 addition & 0 deletions optimizations/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ edition = "2021"
brilirs = { version = "0.1.0", path = "../bril/brilirs" }
clap = "4.5.20"
common = { version = "0.1.0", path = "../common" }
dataflow_analysis = { version = "0.1.0", path = "../dataflow_analysis" }
indoc = "2.0.5"
smallstr = "0.3.0"
smallvec = "1.13.2"
Expand Down
5 changes: 5 additions & 0 deletions optimizations/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
mod local_dead_code_elimination;
mod local_value_numbering;
mod loop_invariant_code_motion;

use std::vec;

Expand All @@ -10,6 +11,7 @@ use common::BasicBlock;
pub enum OptimizationPass {
LocalDeadCodeElimination,
LocalValueNumbering,
LoopInvariantCodeMotion,
}

pub struct PassManager {
Expand All @@ -25,6 +27,9 @@ impl PassManager {
OptimizationPass::LocalValueNumbering => {
Box::new(local_value_numbering::LocalValueNumberingPass::new())
}
OptimizationPass::LoopInvariantCodeMotion => {
Box::new(loop_invariant_code_motion::LoopInvariantCodeMotionPass::new())
}
}
}
pub fn new() -> Self {
Expand Down
151 changes: 151 additions & 0 deletions optimizations/src/loop_invariant_code_motion.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
use crate::Pass;
use common::cfg::{self, Cfg, Dominators};
use dataflow_analysis::{run_analysis, DataflowAnalyses, ReachingDefinitions};
use std::collections::HashSet;

pub struct LoopInvariantCodeMotionPass {}

impl Pass for LoopInvariantCodeMotionPass {
fn apply(&mut self, mut program: bril_rs::Program) -> bril_rs::Program {
let mut output_program = cfg::convert_to_ssa(program);
for function in output_program.functions.iter_mut() {
function.instrs = common::cfg::convert_cfg_to_instruction_stream(
self.process_cfg(Cfg::new(function)),
);
}
output_program
}
}

fn find_back_edges(
cfg: &Cfg,
dominators: &Dominators,
) -> HashSet<(usize /*Start index*/, usize /*End index*/)> {
if cfg.dag.number_of_nodes() == 0 {
return HashSet::new();
}
let mut back_edges = HashSet::new();
let mut visited = vec![false; cfg.dag.number_of_nodes()];
let mut nodes_to_visit = vec![0];
while !nodes_to_visit.is_empty() {
let node = nodes_to_visit.pop().unwrap();
visited[node] = true;
for &successor in cfg.dag.get_successor_indices(node) {
// If the successor is visited and is a dominator of the current node, then it is a back edge.
if visited[successor] && dominators.set_per_node[node].contains(&successor) {
back_edges.insert((node, successor));
} else {
nodes_to_visit.push(successor);
}
}
}
back_edges
}

fn find_loop_nodes(
cfg: &Cfg,
dominators: &Dominators,
loop_header: usize,
seed: usize,
) -> Vec<usize> {
let mut loop_nodes = vec![loop_header];
let mut visited = vec![false; cfg.dag.number_of_nodes()];
visited[loop_header] = true;
let mut nodes_to_visit = vec![seed];
while !nodes_to_visit.is_empty() {
let node = nodes_to_visit.pop().unwrap();
visited[node] = true;
loop_nodes.push(node);
for &predecessor in cfg.dag.get_predecessor_indices(node) {
// If the predecessor is not visited and if the loop header dominates the predecessor, then add it to the list of nodes to visit.
// All nodes in the loop should have the loop header as a dominator.
if !visited[predecessor] && dominators.set_per_node[predecessor].contains(&loop_header)
{
nodes_to_visit.push(predecessor);
}
}
}
loop_nodes
}

impl LoopInvariantCodeMotionPass {
pub fn new() -> Self {
LoopInvariantCodeMotionPass {}
}
fn process_cfg(&mut self, cfg: Cfg) -> Cfg {
// Precondition: We make the assumption that the CFG is reducible.
let dominators = cfg::Dominators::new(&cfg);
let (reaching_definitions_in, reaching_definitions_out) =
ReachingDefinitions {}.run_analysis(&cfg, Some(false));
find_back_edges(&cfg, &dominators)
.iter()
.map(|(src, loop_header)| {
// For each back edge, find the loop nodes.
assert!(
dominators.set_per_node[*src].contains(&loop_header),
"{src}->{loop_header}| dominators of {} are: {:#?}",
loop_header,
dominators.set_per_node[*loop_header]
);
(
*loop_header,
find_loop_nodes(&cfg, &dominators, *loop_header, *src),
)
})
.for_each(
|(loop_header, loop_nodes)| // For each loop, find the loop invariant instructions.
{
// todo!("Find loop invariant instructions for loop: {:#?}", loop_nodes);
},
);

cfg
}
}

#[cfg(test)]
mod tests {
use crate::Pass;
use bril_rs::Program;

fn parse_program(text: &str) -> Program {
let program = common::parse_bril_text(text);
assert!(program.is_ok(), "{}", program.err().unwrap());
program.unwrap()
}

#[test]
fn test_loop_invariant_code_motion() {
let program = parse_program(indoc::indoc! {r#"
@main {
n: int = const 10;
inc: int = const 5;
one: int = const 1;
invariant: int = const 100;
i: int = const 0;
sum: int = const 0;
.loop:
cond: bool = lt i n;
br cond .body .done;
.body:
temp: int = add invariant inc;
sum: int = add sum temp;
i: int = add i one;
body_cond: bool = lt temp sum;
br body_cond .body_left .body_right;
.body_left:
jmp .body_join;
.body_right:
dead_store: int = const 0;
jmp .body_join;
.body_join:
jmp .loop;
.done:
print sum;
ret;
}
"#});
let optimized_program = super::LoopInvariantCodeMotionPass::new().apply(program);
println!("{}", optimized_program);
}
}

0 comments on commit 9f90948

Please sign in to comment.