From 22953241d6e01ed49aa742f0a601b911507c10b2 Mon Sep 17 00:00:00 2001 From: Rob Meng Date: Tue, 4 Feb 2025 18:53:13 -0500 Subject: [PATCH] chore: clean up reader coerce in fragment.py (#3432) --- python/python/lance/fragment.py | 26 +++----------------------- python/python/lance/types.py | 1 + 2 files changed, 4 insertions(+), 23 deletions(-) diff --git a/python/python/lance/fragment.py b/python/python/lance/fragment.py index 495e6552d1..5289cceff8 100644 --- a/python/python/lance/fragment.py +++ b/python/python/lance/fragment.py @@ -24,8 +24,6 @@ import pyarrow as pa -from .dependencies import _check_for_pandas -from .dependencies import pandas as pd from .lance import ( DeletionFile as DeletionFile, ) @@ -257,7 +255,7 @@ def create_from_file( @staticmethod def create( dataset_uri: Union[str, Path], - data: Union[pa.Table, pa.RecordBatchReader], + data: ReaderLike, fragment_id: Optional[int] = None, schema: Optional[pa.Schema] = None, max_rows_per_group: int = 1024, @@ -331,16 +329,7 @@ def create( else: data_storage_version = "stable" - if _check_for_pandas(data) and isinstance(data, pd.DataFrame): - reader = pa.Table.from_pandas(data, schema=schema).to_reader() - elif isinstance(data, pa.Table): - reader = data.to_reader() - elif isinstance(data, pa.dataset.Scanner): - reader = data.to_reader() - elif isinstance(data, pa.RecordBatchReader): - reader = data - else: - raise TypeError(f"Unknown data_obj type {type(data)}") + reader = _coerce_reader(data, schema) if isinstance(dataset_uri, Path): dataset_uri = str(dataset_uri) @@ -797,16 +786,7 @@ def write_fragments( """ from .dataset import LanceDataset - if _check_for_pandas(data) and isinstance(data, pd.DataFrame): - reader = pa.Table.from_pandas(data, schema=schema).to_reader() - elif isinstance(data, pa.Table): - reader = data.to_reader() - elif isinstance(data, pa.dataset.Scanner): - reader = data.to_reader() - elif isinstance(data, pa.RecordBatchReader): - reader = data - else: - raise TypeError(f"Unknown data_obj type {type(data)}") + reader = _coerce_reader(data, schema) if isinstance(dataset_uri, Path): dataset_uri = str(dataset_uri) diff --git a/python/python/lance/types.py b/python/python/lance/types.py index b0559c5ff1..498103cb40 100644 --- a/python/python/lance/types.py +++ b/python/python/lance/types.py @@ -18,6 +18,7 @@ pa.Table, pa.dataset.Dataset, pa.dataset.Scanner, + pa.RecordBatch, Iterable[RecordBatch], pa.RecordBatchReader, ]