# Licensed under a 3-clause BSD style license - see PYFITS.rst

import sys

import numpy as np

from astropy.io import fits

from . import FitsTestCase


def _resolve_field_name(name, available):
    """
    Return the spelling of ``name`` that is present in ``available``.

    The name is tried as given, then lower-cased, then upper-cased, and the
    first spelling found wins.

    Raises
    ------
    ValueError
        If none of the three spellings is present in ``available``.
    """
    lowered = name.lower()
    for candidate in (name, lowered, lowered.upper()):
        if candidate in available:
            return candidate
    raise ValueError(f'field name {name} not found in array 1')


def compare_arrays(arr1in, arr2in, verbose=False):
    """
    Compare the values field-by-field in two sets of numpy arrays or
    recarrays.

    Every field of ``arr2in`` is looked up in ``arr1in`` (exact name first,
    then lower case, then upper case).  The matching fields are compared by
    shape and, when the shapes agree, element by element.  Fields that exist
    only in ``arr1in`` are not examined.

    Parameters
    ----------
    arr1in, arr2in : structured `numpy.ndarray` or subclass
        The arrays to compare; both are viewed as plain `numpy.ndarray`
        before any comparison is made.
    verbose : bool, optional
        If True, write a per-field progress report to ``sys.stdout``.

    Returns
    -------
    bool
        True if no field differed in shape or content, False otherwise.

    Raises
    ------
    ValueError
        If a field of ``arr2in`` cannot be found in ``arr1in`` under any of
        the three spellings tried.
    """
    arr1 = arr1in.view(np.ndarray)
    arr2 = arr2in.view(np.ndarray)

    def report(text):
        # Single place that honours ``verbose``; sys.stdout is looked up at
        # call time so output capturing keeps working.
        if verbose:
            sys.stdout.write(text)

    nfail = 0
    for n2 in arr2.dtype.names:
        n1 = _resolve_field_name(n2, arr1.dtype.names)
        field1 = arr1[n1]
        field2 = arr2[n2]

        report(f" testing field: '{n2}'\n")
        report(' shape...........')
        if field2.shape != field1.shape:
            # A shape mismatch counts as one failure; the element-wise
            # comparison is skipped for this field.
            nfail += 1
            report('shapes differ\n')
            continue

        report('OK\n')
        report(' elements........')
        ndiff = np.count_nonzero(field1.ravel() != field2.ravel())
        if ndiff:
            # One failure per differing field, however many elements differ.
            nfail += 1
            report(f'\n {ndiff} elements in field {n2} differ\n')
        else:
            report('OK\n')

    if nfail:
        report(f'{nfail} differences found\n')
        return False

    report('All tests passed\n')
    return True


def get_test_data(verbose=False):
    """
    Build a small, reproducible three-row structured array.

    The array has the fields ``f1`` (``'i4'``, values 1, 3, 5), ``f2``
    (``'S6'``, three short byte strings) and ``f3`` (``'>2f8'``, two
    big-endian doubles per row filled with pseudo-random numbers from a
    generator seeded with 35).

    Parameters
    ----------
    verbose : bool, optional
        Unused; retained so existing callers keep working.

    Returns
    -------
    `numpy.ndarray`
        The structured array described above.
    """
    st = np.zeros(3, [('f1', 'i4'), ('f2', 'S6'), ('f3', '>2f8')])

    st['f1'] = [1, 3, 5]
    st['f2'] = ['hello', 'world', 'byebye']

    # A private RandomState seeded with 35 yields the same numbers as
    # seeding the global generator would, without altering the global
    # random state seen by whatever runs after this function.
    rng = np.random.RandomState(35)
    st['f3'] = rng.random_sample(st['f3'].shape)

    return st


class TestStructured(FitsTestCase):
    """Round-trip structured arrays through a FITS file."""

    def test_structured(self):
        """
        Write three tables into one file and check that each one reads back
        equal to what was written.
        """
        # NOTE(review): self.data / self.temp come from FitsTestCase;
        # presumably they resolve a bundled data file and a scratch path
        # respectively -- confirm against the base class.
        fname = self.data('stddata.fits')

        # header=True is kept so the (data, header) return form is still
        # exercised; only the data are compared below.
        data1, _ = fits.getdata(fname, ext=1, header=True)
        data2, _ = fits.getdata(fname, ext=2, header=True)

        st = get_test_data()

        # Resulting layout: ext 1 = data1, ext 2 = data2, ext 3 = st.
        outfile = self.temp('test.fits')
        fits.writeto(outfile, data1, overwrite=True)
        fits.append(outfile, data2)
        fits.append(outfile, st)

        # Checked after the write: presumably guards against the writer
        # modifying the caller's in-memory array.
        assert st.dtype.isnative
        assert np.all(st['f1'] == [1, 3, 5])

        data1check, _ = fits.getdata(outfile, ext=1, header=True)
        data2check, _ = fits.getdata(outfile, ext=2, header=True)
        stcheck, _ = fits.getdata(outfile, ext=3, header=True)

        assert compare_arrays(data1, data1check, verbose=True)
        assert compare_arrays(data2, data2check, verbose=True)
        assert compare_arrays(st, stcheck, verbose=True)

        # try reading with view
        dataviewcheck, _ = fits.getdata(outfile, ext=2, header=True,
                                        view=np.ndarray)
        assert compare_arrays(data2, dataviewcheck, verbose=True)