Skip to content

Commit

Permalink
Merge pull request #81 from fact-project/add_sample_fraction
Browse files Browse the repository at this point in the history
Write sample_fraction into hdf files in split_data
  • Loading branch information
maxnoe authored May 16, 2019
2 parents 572d7ea + 8d079b1 commit 4cca484
Show file tree
Hide file tree
Showing 5 changed files with 44 additions and 6 deletions.
8 changes: 8 additions & 0 deletions aict_tools/scripts/fact_to_dl3.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import pandas as pd
from functools import partial
import os
import h5py

from astropy.time import Time
from astropy.coordinates import AltAz, SkyCoord
Expand Down Expand Up @@ -369,6 +370,13 @@ def main(
else:
to_h5py(df[dl3_columns_sim], output, key='events', mode='a')

with h5py.File(data_path, 'r') as f:
sample_fraction = f.attrs.get('sample_fraction')

if sample_fraction is not None:
with h5py.File(output, 'r+') as f:
f.attrs['sample_fraction'] = sample_fraction

if source:
log.info('Copying "runs" group')
to_h5py(runs, output, key='runs', mode='a')
Expand Down
13 changes: 13 additions & 0 deletions aict_tools/scripts/split_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from math import ceil
import h5py
from tqdm import tqdm
import h5py

log = logging.getLogger()

Expand Down Expand Up @@ -117,6 +118,10 @@ def split_multi_telescope_data(input_path, output_basename, fraction, name):
write_data(selected_runs, path, key='runs', use_h5py=True, mode='w')
write_data(selected_array_events, path, key='array_events', use_h5py=True, mode='a')
write_data(selected_telescope_events, path, key='telescope_events', use_h5py=True, mode='a')

with h5py.File(path, 'r+') as f:
f.attrs['sample_fraction'] = n / n_total

log.debug(f'selected runs {set(selected_run_ids)}')
log.debug(f'Runs minus selected runs {ids - set(selected_run_ids)}')
ids = ids - set(selected_run_ids)
Expand Down Expand Up @@ -160,6 +165,11 @@ def split_single_telescope_data_chunked(input_path, output_basename, inkey, key,
))
write_data(selected_data, path, key=key, use_h5py=True, mode=mode)

for n, part_name in zip(num_ids, name):
path = output_basename + '_' + part_name + '.hdf5'
with h5py.File(path, mode='r+') as f:
f.attrs['sample_fraction'] = n / n_total


def split_single_telescope_data(input_path, output_basename, fmt, inkey, key, fraction, name):

Expand All @@ -184,6 +194,9 @@ def split_single_telescope_data(input_path, output_basename, fmt, inkey, key, fr
log.info('Writing {} telescope-array events to: {}'.format(n, path))
write_data(selected_data, path, key=key, use_h5py=True, mode='w')

with h5py.File(path, mode='r+') as f:
f.attrs['sample_fraction'] = n / n_total

elif fmt == 'csv':
filename = output_basename + '_' + part_name + '.csv'
log.info('Writing {} telescope-array events to: {}'.format(n, filename))
Expand Down
Binary file modified examples/gamma.hdf5
Binary file not shown.
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

setup(
name='aict_tools',
version='0.15.0',
version='0.16.0',
description='Artificial Intelligence for Imaging Atmospheric Cherenkov Telescopes',
long_description=long_description,
long_description_content_type='text/markdown',
Expand Down
27 changes: 22 additions & 5 deletions tests/test_executables.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from click.testing import CliRunner
import shutil
from traceback import print_exception
import h5py


def test_train_regressor():
Expand Down Expand Up @@ -266,6 +267,7 @@ def test_to_dl3():
print_exception(*result.exc_info)
assert result.exit_code == 0

output = os.path.join(d, 'gamma_dl3.hdf5')
result = runner.invoke(
to_dl3,
[
Expand All @@ -275,7 +277,7 @@ def test_to_dl3():
os.path.join(d, 'regressor.pkl'),
os.path.join(d, 'disp.pkl'),
os.path.join(d, 'sign.pkl'),
os.path.join(d, 'gamma_dl3.hdf5'),
output,
]
)

Expand All @@ -284,6 +286,9 @@ def test_to_dl3():
print_exception(*result.exc_info)
assert result.exit_code == 0

with h5py.File(output) as f:
assert f.attrs['sample_fraction'] == 1000 / 1851297


def test_split_data_executable():
from aict_tools.scripts.split_data import main as split
Expand All @@ -299,9 +304,9 @@ def test_split_data_executable():
os.path.join(d, 'gamma.hdf5'),
os.path.join(d, 'signal'),
'-ntest', # no spaces here. maybe a bug in click?
'-f0.5',
'-f0.75',
'-ntrain',
'-f0.5',
'-f0.25',
]
)
if result.exit_code != 0:
Expand All @@ -313,9 +318,15 @@ def test_split_data_executable():
test_path = os.path.join(d, 'signal_test.hdf5')
assert os.path.isfile(test_path)

with h5py.File(test_path, 'r') as f:
assert f.attrs['sample_fraction'] == 0.75

train_path = os.path.join(d, 'signal_train.hdf5')
assert os.path.isfile(train_path)

with h5py.File(train_path, 'r') as f:
assert f.attrs['sample_fraction'] == 0.25


def test_split_data_executable_chunked():
from aict_tools.scripts.split_data import main as split
Expand All @@ -331,9 +342,9 @@ def test_split_data_executable_chunked():
os.path.join(d, 'gamma.hdf5'),
os.path.join(d, 'signal'),
'-ntest', # no spaces here. maybe a bug in click?
'-f0.5',
'-f0.75',
'-ntrain',
'-f0.5',
'-f0.25',
'--chunksize=100',
]
)
Expand All @@ -346,5 +357,11 @@ def test_split_data_executable_chunked():
test_path = os.path.join(d, 'signal_test.hdf5')
assert os.path.isfile(test_path)

with h5py.File(test_path, 'r') as f:
assert f.attrs['sample_fraction'] == 0.75

train_path = os.path.join(d, 'signal_train.hdf5')
assert os.path.isfile(train_path)

with h5py.File(train_path, 'r') as f:
assert f.attrs['sample_fraction'] == 0.25

0 comments on commit 4cca484

Please sign in to comment.