Add Image Curation Tutorial #254

Merged · 7 commits · Sep 20, 2024

nemo_curator/modules/semantic_dedup.py (31 additions & 11 deletions)
@@ -39,7 +39,7 @@
 from nemo_curator.utils.distributed_utils import write_to_disk
 from nemo_curator.utils.file_utils import expand_outdir_and_mkdir
 from nemo_curator.utils.semdedup_utils import (
-    _assign_and_sort_clusters,
+    assign_and_sort_clusters,
     extract_dedup_data,
     get_semantic_matches_per_cluster,
 )
@@ -123,6 +123,7 @@ def __init__(
         embedding_batch_size: int,
         embedding_output_dir: str,
         input_column: str = "text",
+        embedding_column: str = "embeddings",
         write_embeddings_to_disk: bool = True,
         write_to_filename: bool = False,
         logger: Union[logging.Logger, str] = "./",
@@ -160,6 +161,7 @@ def __init__(
         self.logger = self._setup_logger(logger)
         self.embedding_output_dir = embedding_output_dir
         self.input_column = input_column
+        self.embedding_column = embedding_column
         self.model = EmbeddingCrossFitModel(self.embeddings_config)
         self.write_embeddings_to_disk = write_embeddings_to_disk
         self.write_to_filename = write_to_filename
@@ -190,7 +192,7 @@ def create_embeddings(
                 self.model,
                 sorted_data_loader=True,
                 batch_size=self.batch_size,
-                pred_output_col="embeddings",
+                pred_output_col=self.embedding_column,
             ),
             keep_cols=ddf.columns.tolist(),
         )
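
Usage sketch (reviewer note): the new parameter names the column the embeddings are written to. A minimal sketch, assuming (as in NeMo Curator) that this __init__ belongs to EmbeddingCreator and that a model name/path argument precedes the parameters shown; the model name and column value here are illustrative, not part of this PR.

from nemo_curator.modules.semantic_dedup import EmbeddingCreator

embedding_creator = EmbeddingCreator(
    embedding_model_name_or_path="sentence-transformers/all-MiniLM-L6-v2",  # assumed parameter name
    embedding_batch_size=128,
    embedding_output_dir="./embeddings",
    input_column="text",
    embedding_column="image_embeddings",  # new knob; previously hard-coded to "embeddings"
)
embeddings_dataset = embedding_creator(dataset)  # dataset: a DocumentDataset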
@@ -215,12 +217,14 @@ def __call__(self, dataset: DocumentDataset) -> DocumentDataset:
 
 
 ### Clustering Module
-def get_embedding_ar(df: "cudf.DataFrame") -> cp.ndarray:
-    return df["embeddings"].list.leaves.values.reshape(len(df), -1)
+def get_embedding_ar(df: "cudf.DataFrame", embedding_col: str) -> cp.ndarray:
+    return df[embedding_col].list.leaves.values.reshape(len(df), -1)
 
 
-def add_dist_to_cents(df: "cudf.DataFrame", centroids: cp.ndarray) -> "cudf.DataFrame":
-    embed_array = get_embedding_ar(df)
+def add_dist_to_cents(
+    df: "cudf.DataFrame", embedding_col: str, centroids: cp.ndarray
+) -> "cudf.DataFrame":
+    embed_array = get_embedding_ar(df, embedding_col)
     centroids_ar = centroids[df["nearest_cent"].values]
     dist_to_cents = cp.sqrt(np.sum((embed_array - centroids_ar) ** 2, axis=1))
     df["dist_to_cent"] = dist_to_cents
@@ -234,6 +238,7 @@ def __init__(
         max_iter: int,
         n_clusters: int,
         clustering_output_dir: str,
+        embedding_col: str = "embeddings",
         sim_metric: str = "cosine",
         which_to_keep: str = "hard",
         sort_clusters: bool = True,
@@ -249,6 +254,7 @@ def __init__(
             max_iter (int): Maximum number of iterations for the clustering algorithm.
             n_clusters (int): The number of clusters to form.
             clustering_output_dir (str): Directory path where clustering results will be saved.
+            embedding_col (str): Column name where the embeddings are stored.
             sim_metric (str): Similarity metric to use for clustering, default is "cosine".
             which_to_keep (str): Strategy to decide which duplicates to keep; default is "hard".
             sort_clusters (bool): Whether to sort clusters, default is True.
@@ -262,6 +268,7 @@ def __init__(
         self.max_iter = max_iter
         self.n_clusters = n_clusters
         self.clustering_output_dir = clustering_output_dir
+        self.embedding_col = embedding_col
         self.sim_metric = sim_metric
         self.keep_hard = which_to_keep == "hard"
         self.kmeans_with_cos_dist = kmeans_with_cos_dist
@@ -291,15 +298,20 @@ def _setup_logger(self, logger):
     def __call__(self, embeddings_dataset: DocumentDataset):
         embeddings_df = embeddings_dataset.df
 
-        assert "embeddings" in embeddings_df.columns
-        embeddings_df = embeddings_df[[self.id_col, "embeddings"]]
+        if self.embedding_col not in embeddings_df.columns:
+            raise ValueError(
+                f"Expected embedding column '{self.embedding_col}'"
+                f" to be in dataset. Only found columns {embeddings_df.columns}"
+            )
+
+        embeddings_df = embeddings_df[[self.id_col, self.embedding_col]]
 
         embeddings_df = embeddings_df.to_backend("pandas").persist()
         embeddings_df = embeddings_df.repartition(partition_size=self.partition_size)
         embeddings_df = embeddings_df.to_backend("cudf")
 
         cupy_darr = embeddings_df.map_partitions(
-            get_embedding_ar, meta=cp.ndarray([1, 1])
+            get_embedding_ar, self.embedding_col, meta=cp.ndarray([1, 1])
         )
         cupy_darr.compute_chunk_sizes()
 
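
The assert-to-ValueError change makes a misconfigured column name fail with an actionable message instead of a bare AssertionError. A hedged sketch of the new behavior (the class name ClusteringModel and its leading id_col parameter are assumed from this module, since they sit above the fold; requires a GPU/cudf environment):

import dask.dataframe as dd
import pandas as pd
from nemo_curator.datasets import DocumentDataset
from nemo_curator.modules.semantic_dedup import ClusteringModel

# A dataset deliberately missing the configured "image_embeddings" column
ds = DocumentDataset(
    dd.from_pandas(pd.DataFrame({"id": [0], "text": ["a"]}), npartitions=1)
)

clustering_model = ClusteringModel(
    id_col="id",  # assumed parameter name, hidden above the fold
    max_iter=100,
    n_clusters=1000,
    clustering_output_dir="./clustering_results",
    embedding_col="image_embeddings",
)
clustering_model(ds)  # raises ValueError: Expected embedding column 'image_embeddings' ...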
@@ -317,7 +329,10 @@ def __call__(self, embeddings_dataset: DocumentDataset):
         meta_df = embeddings_df._meta.copy()
         meta_df["dist_to_cent"] = cp.zeros(1)
         embeddings_df = embeddings_df.map_partitions(
-            add_dist_to_cents, centroids=kmeans.cluster_centers_, meta=meta_df
+            add_dist_to_cents,
+            embedding_col=self.embedding_col,
+            centroids=kmeans.cluster_centers_,
+            meta=meta_df,
         )
         centroids = kmeans.cluster_centers_
         embeddings_df = embeddings_df.reset_index(drop=True)
@@ -348,13 +363,14 @@ def __call__(self, embeddings_dataset: DocumentDataset):
         del embeddings_df
 
         if self.sort_clusters:
-            _assign_and_sort_clusters(
+            assign_and_sort_clusters(
                 id_col=self.id_col,
                 kmeans_centroids_file=kmeans_centroids_file,
                 nearest_cent_dir=clustering_output_dir,
                 output_sorted_clusters_dir=os.path.join(
                     self.clustering_output_dir, "sorted"
                 ),
+                embedding_col=self.embedding_col,
                 sim_metric=self.sim_metric,
                 keep_hard=self.keep_hard,
                 kmeans_with_cos_dist=self.kmeans_with_cos_dist,
@@ -380,6 +396,7 @@ def __init__(
         id_col_type: str,
         which_to_keep: str,
         output_dir: str,
+        embedding_col: str = "embeddings",
         logger: Union[logging.Logger, str] = "./",
     ) -> None:
         """
@@ -393,6 +410,7 @@ def __init__(
             id_col_type (str): Data type of the ID column.
             which_to_keep (str): Strategy for which duplicate to keep.
             output_dir (str): Directory to save output files.
+            embedding_col (str): Column where the embeddings are stored.
             logger (Union[logging.Logger, str]): Logger instance or path to the log file directory.
         """
         self.n_clusters = n_clusters
@@ -406,6 +424,7 @@ def __init__(
             output_dir, "semdedup_pruning_tables"
         )
         self.computed_semantic_match_dfs = False
+        self.embedding_col = embedding_col
         self.logger = self._setup_logger(logger)
 
     def _setup_logger(self, logger: Union[logging.Logger, str]) -> logging.Logger:
@@ -461,6 +480,7 @@ def compute_semantic_match_dfs(
                     id_col_type=self.id_col_type,
                     eps_list=eps_list,
                     output_dir=self.semdedup_pruning_tables_dir,
+                    embedding_col=self.embedding_col,
                     which_to_keep=self.which_to_keep,
                 )
             )
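
Taken together, callers can now thread one column name through every stage; the embedding_col passed downstream must match the embedding_column used at creation. A sketch of the downstream wiring (SemanticClusterLevelDedup's directory and id parameters are not visible in this diff, so their names and values here are assumptions):

from nemo_curator.modules.semantic_dedup import SemanticClusterLevelDedup

semantic_dedup = SemanticClusterLevelDedup(
    n_clusters=1000,
    emb_by_clust_dir="./clustering_results/embs_by_nearest_center",  # assumed name/path
    sorted_clusters_dir="./clustering_results/sorted",               # assumed name/path
    id_col="id",                                                     # assumed name
    id_col_type="str",
    which_to_keep="hard",
    output_dir="./dedup_results",
    embedding_col="image_embeddings",  # must match EmbeddingCreator's embedding_column
)
semantic_dedup.compute_semantic_match_dfs([0.01, 0.001])  # eps thresholds; parameter name not shown in this diff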

nemo_curator/utils/semdedup_utils.py (18 additions & 8 deletions)
@@ -31,12 +31,13 @@
 from nemo_curator.utils.file_utils import expand_outdir_and_mkdir
 
 
-def _assign_and_sort_clusters(
+def assign_and_sort_clusters(
     id_col: str,
     kmeans_centroids_file: str,
     nearest_cent_dir: str,
     output_sorted_clusters_dir: str,
-    cluster_ids=List[int],
+    cluster_ids: List[int],
+    embedding_col: str,
     sim_metric: str = "cosine",
     keep_hard: bool = True,
     kmeans_with_cos_dist: bool = True,
@@ -78,6 +79,7 @@ def _assign_and_sort_clusters(
         nearest_cent_dir=nearest_cent_dir,
         output_sorted_clusters_dir=output_sorted_clusters_dir,
         centroids=kmeans_centroids,
+        embedding_col=embedding_col,
         sim_metric=sim_metric,
         keep_hard=keep_hard,
         kmeans_with_cos_dist=kmeans_with_cos_dist,
@@ -98,6 +100,7 @@ def rank_within_cluster(
     nearest_cent_dir: str,
     output_sorted_clusters_dir: str,
     centroids: np.ndarray,
+    embedding_col: str,
     sim_metric: str = "cosine",
     keep_hard: bool = True,
     kmeans_with_cos_dist: bool = False,
@@ -131,10 +134,10 @@ def rank_within_cluster(
             continue
 
         cluster_df = cudf.read_parquet(
-            cluster_c_path, columns=[id_col, "dist_to_cent", "embeddings"]
+            cluster_c_path, columns=[id_col, "dist_to_cent", embedding_col]
         )
         embeds = torch.as_tensor(
-            cluster_df["embeddings"].list.leaves.values.reshape(
+            cluster_df[embedding_col].list.leaves.values.reshape(
                 cluster_df.shape[0], -1
             ),
             device="cuda",
@@ -188,11 +191,15 @@ def _semdedup(
 
 
 def get_cluster_reps(
-    cluster_id: int, emb_by_clust_dir: str, id_col: str, sorted_ids: np.ndarray
+    cluster_id: int,
+    emb_by_clust_dir: str,
+    id_col: str,
+    embedding_col: str,
+    sorted_ids: np.ndarray,
 ) -> torch.Tensor:
     cluster_i_path = os.path.join(emb_by_clust_dir, f"nearest_cent={cluster_id}")
     cluster_reps = cudf.read_parquet(
-        cluster_i_path, columns=["embeddings", id_col]
+        cluster_i_path, columns=[embedding_col, id_col]
     ).sort_values(by=id_col)
     num = cluster_reps.shape[0]
 
@@ -203,7 +210,7 @@ def get_cluster_reps(
     cluster_reps = cluster_reps.sort_values(by="inverse_sort_id")
 
     cluster_reps = torch.as_tensor(
-        cluster_reps["embeddings"].list.leaves.values.reshape(len(cluster_reps), -1),
+        cluster_reps[embedding_col].list.leaves.values.reshape(len(cluster_reps), -1),
         device="cuda",
     )
     return cluster_reps
@@ -217,6 +224,7 @@ def get_semantic_matches_per_cluster(
     id_col_type: str,
     eps_list: List[float],
     output_dir: str,
+    embedding_col: str,
     which_to_keep: str,
 ) -> None:
 
@@ -251,7 +259,9 @@ def get_semantic_matches_per_cluster(
 
     text_ids = cluster_i[:, 0].astype(id_col_type)
 
-    cluster_reps = get_cluster_reps(cluster_id, emb_by_clust_dir, id_col, text_ids)
+    cluster_reps = get_cluster_reps(
+        cluster_id, emb_by_clust_dir, id_col, embedding_col, text_ids
+    )
     M, M1 = _semdedup(cluster_reps, "cuda")
     assert cluster_reps.shape[0] == len(text_ids)
 
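
With the rename, the helper is part of the importable surface. A direct-call sketch using only the parameters visible in this diff (file paths and the cluster count are illustrative; parameters hidden behind the fold keep their defaults here):

from nemo_curator.utils.semdedup_utils import assign_and_sort_clusters

assign_and_sort_clusters(
    id_col="id",
    kmeans_centroids_file="./clustering_results/kmeans_centroids.npy",  # illustrative path
    nearest_cent_dir="./clustering_results/embs_by_nearest_center",     # illustrative path
    output_sorted_clusters_dir="./clustering_results/sorted",
    cluster_ids=list(range(1000)),
    embedding_col="image_embeddings",
    sim_metric="cosine",
    keep_hard=True,
    kmeans_with_cos_dist=True,
)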