diff --git a/docs/apidocs/docs/index.md b/docs/apidocs/docs/index.md index db77db6..6d339c4 100644 --- a/docs/apidocs/docs/index.md +++ b/docs/apidocs/docs/index.md @@ -1,7 +1,12 @@ # Hydra API documentation -- [Example simulations (`example`)](example.md) -- [Point source sampler (`ptsrc_sampler`)](ptsrc_sampler.md) +## Samplers (heads) of Hydra - [Diffuse emission region sampler (`region_sampler`)](region_sampler.md) +- [Point source sampler (`ptsrc_sampler`)](ptsrc_sampler.md) - [Spherical harmonic sampler (`sh_sampler`)](sh_sampler.md) +## Utility and simulation functions +- [Example simulations (`example`)](example.md) +- [Linear solvers (`linear_solver`)](linear_solver.md) +- [Utility functions (`utils`)](utils.md) +- [Visibility simulators (`vis_simulator`)](vis_simulator.md) diff --git a/docs/apidocs/docs/linear_solver.md b/docs/apidocs/docs/linear_solver.md new file mode 100644 index 0000000..c2c3545 --- /dev/null +++ b/docs/apidocs/docs/linear_solver.md @@ -0,0 +1,6 @@ +# Linear solvers (`linear_solver`) + +::: hydra.linear_solver + options: + show_root_heading: false + diff --git a/docs/apidocs/docs/utils.md b/docs/apidocs/docs/utils.md new file mode 100644 index 0000000..2e179c5 --- /dev/null +++ b/docs/apidocs/docs/utils.md @@ -0,0 +1,6 @@ +# Utility functions (`utils`) + +::: hydra.utils + options: + show_root_heading: false + diff --git a/docs/apidocs/docs/vis_simulator.md b/docs/apidocs/docs/vis_simulator.md new file mode 100644 index 0000000..1d94c79 --- /dev/null +++ b/docs/apidocs/docs/vis_simulator.md @@ -0,0 +1,6 @@ +# Visibility simulators (`vis_simulator`) + +::: hydra.vis_simulator + options: + show_root_heading: false + diff --git a/example.py b/example.py index 7d60d3c..c263f38 100644 --- a/example.py +++ b/example.py @@ -89,12 +89,13 @@ sim_gain_amp_std = args.sim_gain_amp_std # Source position and LST/frequency ranges - #ra_low, ra_high = (min(args.ra_bounds), max(args.ra_bounds)) - #dec_low, dec_high = (min(args.dec_bounds), max(args.dec_bounds)) - lst_min, lst_max = (min(args.lst_bounds), max(args.lst_bounds)) + # LSTs specified in hours, but converted to radians + # Freqs. specified in MHz + lst_min, lst_max = (min(args.lst_bounds) * 2.*np.pi/24., + max(args.lst_bounds) * 2.*np.pi/24.) freq_min, freq_max = (min(args.freq_bounds), max(args.freq_bounds)) - # Array latitude + # Array latitude, in degrees array_latitude = np.deg2rad(args.latitude) #-------------------------------------------------------------------------- @@ -181,11 +182,14 @@ ptsrc_amps = np.zeros_like(ra) if myid == 0: - # Generate random catalogue - ra, dec, ptsrc_amps = generate_random_ptsrc_catalogue(Nptsrc, - ra_bounds=args.ra_bounds, - dec_bounds=args.dec_bounds, - logflux_bounds=(-1., 2.)) + # Generate random catalogue (convert input ra/dec from deg to rad) + # Returned ra and dec arrays are now in radians + ra, dec, ptsrc_amps = generate_random_ptsrc_catalogue( + Nptsrc, + ra_bounds=np.deg2rad(args.ra_bounds), + dec_bounds=np.deg2rad(args.dec_bounds), + logflux_bounds=(-1., 2.) + ) # Save generated catalogue info np.save(os.path.join(output_dir, "ptsrc_amps0"), ptsrc_amps) np.save(os.path.join(output_dir, "ptsrc_coords0"), np.column_stack((ra, dec)).T) @@ -294,20 +298,21 @@ calsrc = True calsrc_std = args.calsrc_std - # Select what would be the calibration source (brightest, close to beam) - calsrc_idxs = np.where(np.abs(dec - array_latitude)*180./np.pi < calsrc_radius)[0] - assert len(calsrc_idxs) > 0, "No sources found within %d deg of the zenith" % calsrc_radius - calsrc_idx = calsrc_idxs[np.argmax(ptsrc_amps[calsrc_idxs])] - calsrc_amp = ptsrc_amps[calsrc_idx] - if myid == 0: - print("Calibration source:") - print(" Enabled: %s" % calsrc) - print(" Index: %d" % calsrc_idx) - print(" Amplitude: %6.3e" % calsrc_amp) - print(" Dist. from zenith: %6.2f deg" \ - % np.rad2deg(np.abs(dec[calsrc_idx] - array_latitude))) - print(" Flux @ lowest freq: %6.3e Jy" % fluxes_chunk[calsrc_idx,0]) - print("") + # Select what would be the calibration source (brightest, close to beam) + calsrc_idxs = np.where(np.abs(dec - array_latitude)*180./np.pi < calsrc_radius)[0] + assert len(calsrc_idxs) > 0, \ + "No sources found within %d deg of the zenith" % calsrc_radius + calsrc_idx = calsrc_idxs[np.argmax(ptsrc_amps[calsrc_idxs])] + calsrc_amp = ptsrc_amps[calsrc_idx] + if myid == 0: + print("Calibration source:") + print(" Enabled: %s" % calsrc) + print(" Index: %d" % calsrc_idx) + print(" Amplitude: %6.3e" % calsrc_amp) + print(" Dist. from zenith: %6.2f deg" \ + % np.rad2deg(np.abs(dec[calsrc_idx] - array_latitude))) + print(" Flux @ lowest freq: %6.3e Jy" % fluxes_chunk[calsrc_idx,0]) + print("") #-------------------------------------------------------------------------- @@ -446,7 +451,7 @@ if myid == 0: status(None, "Ptsrc amp. prior level: %s" % ptsrc_amp_prior_level, colour='b') if calsrc: - amp_prior_std[calsrc_idx] = calsrc_std + ptsrc_amp_prior_std[calsrc_idx] = calsrc_std # Precompute gain perturbation projection operators A_real, A_imag = None, None @@ -955,8 +960,10 @@ comm.Bcast(x_soln, root=0) comm.barrier() if myid == 0: - status(myid, " Example ptsrc soln:" + str(x_soln[:3])) - status(myid, " Example region soln:" + str(x_soln[Nptsrc:Nptsrc+3])) + status(None, " Example ptsrc soln: " + str(x_soln[:3])) + status(None, " Example region soln: " + str(x_soln[Nptsrc:Nptsrc+3])) + status(None, " Ptsrc soln. avg.: %+8.6f +/- %8.6f" \ + % (np.mean(x_soln[:3]), np.std(x_soln[:3]))) # Update visibility model with latest solution (does not include any gains) # Applies projection operator to ptsrc amplitude vector diff --git a/hydra/__init__.py b/hydra/__init__.py index f6b1f50..7f30ae4 100644 --- a/hydra/__init__.py +++ b/hydra/__init__.py @@ -7,5 +7,5 @@ sh_sampler, vis_sampler, ) -from . import config, example, linear_solver, plot, sparse_beam, utils, vis_simulator +from . import config, example, io, linear_solver, plot, sparse_beam, utils, vis_simulator from .utils import * diff --git a/hydra/config.py b/hydra/config.py index 406f6fa..d1a8374 100644 --- a/hydra/config.py +++ b/hydra/config.py @@ -222,31 +222,31 @@ def get_config(): "--ra-bounds", type=float, action="store", - default=(0, 1), + default=(0.0, 60.), nargs=2, required=False, dest="ra_bounds", - help="Bounds for the Right Ascension of the randomly simulated sources", + help="Bounds for the Right Ascension of the randomly simulated sources, in degrees.", ) parser.add_argument( "--dec-bounds", type=float, action="store", - default=(-0.6, 0.4), + default=(-40.7, -20.7), nargs=2, required=False, dest="dec_bounds", - help="Bounds for the Declination of the randomly simulated sources", + help="Bounds for the Declination of the randomly simulated sources, in degrees.", ) parser.add_argument( "--lst-bounds", type=float, action="store", - default=(0.2, 0.5), + default=(0.75, 1.9), nargs=2, required=False, dest="lst_bounds", - help="Bounds for the LST range of the simulation, in radians.", + help="Bounds for the LST range of the simulation, in hours.", ) parser.add_argument( "--freq-bounds", diff --git a/hydra/io.py b/hydra/io.py new file mode 100644 index 0000000..57d81f7 --- /dev/null +++ b/hydra/io.py @@ -0,0 +1,209 @@ + +import numpy as np +from pyuvdata import UVData + + +def load_uvdata_metadata(comm, fname): + """ + Load metadata from a UVData-compatible file and distribute + it to all MPI workers. + + Parameters: + comm (MPI Communicator): + Optional MPI communicator. The root node will load the metadata + and broadcast it to all other workers. + fname (str): + Path to the data file, which should be a UVH5 file that supports + partial loading. + + Returns: + data_info (dict): + Dictionary with several named properties of the data. + """ + myid = 0 + if comm is not None: + myid = comm.Get_rank() + + # Root worker to load metadata and distribute it + if myid == 0: + uvd = UVData() + uvd.read(fname, read_data=False) # metadata only + + # Get frequency and LST arrays + freqs = np.unique(uvd.freq_array) / 1e6 # MHz + lsts = np.unique(uvd.lst_array) + + # Get array latitude etc. + lat, lon, alt = uvd.telescope_location_lat_lon_alt_degrees + + # Get array baselines + bl_ints = uvd.get_baseline_nums() # Only baselines with data + antpairs = [] + for bl in bl_ints: + a1, a2 = uvd.baseline_to_antnums(bl) + + # Exclude autos + if a1 != a2: + antpairs.append((a1, a2)) + + ants1, ants2 = zip(*antpairs) + + # Get array antenna locations + ant_ids_in_order = uvd.antenna_numbers + ant_ids = np.unique(np.concatenate((ants1, ants2))) + ants = {ant: list(uvd.antenna_positions[ant_ids_in_order == ant,:]) + for ant in ant_ids} + + # Put data in dict with named fields to avoid ambiguity + # Use built-in Python datatypes to help MPI + data_info = { + 'freqs': list(freqs), + 'lsts': list(lsts), + 'lat': float(lat), + 'lon': float(lon), + 'alt': float(alt), + 'antpairs': antpairs, + 'ants1': list(ants1), + 'ants2': list(ants2), + 'ants': ants + } + + # Return data immediately if MPI not enabled + if comm is None: + return data_info + else: + # Start with empty object for all other workers + data_info = None + + # Broadcast data to all workers + data_info = comm.bcast(data_info, root=0) + return data_info + + +def partial_load_uvdata(fname, freq_chunk, lst_chunk, antpairs, pol='xx'): + """ + Load data from a UVData file and unpack into the expected format. + Uses the partial loading feature of UVH5 files. + + Parameters: + fname (str): + Path to the data file, which should be a UVH5 file that supports + partial loading. + freqs (array_like): + Data frequencies that this worker should load, in MHz. + lsts (array_like): + Data LSTs that this worker should load, in radians. + bls (array_like): + Data baselines that this worker should load. These should be + provided as antenna pairs. + pol (str): + Which polarisation to retrieve from the data. + + Returns: + data (array_like): + Array of complex visibility data, with shape + `(Nbls, Nfreqs, Nlsts)`. + flags (array_like): + Array of integer flags to apply to the data. Same shape as the + `data` array. + """ + # Create new object + uvd = UVData() + uvd.read(fname, + frequencies=np.array(freq_chunk)*1e6, + lsts=lst_chunk, + bls=antpairs) + + # Get data and flags + data = np.zeros((len(antpairs), len(freq_chunk), len(lst_chunk)), + dtype=np.complex128) + flags = np.zeros((len(antpairs), len(freq_chunk), len(lst_chunk)), + dtype=np.int32) + + # Loop over baselines and extract data + for i, bl in enumerate(antpairs): + ant1, ant2 = bl + dd = uvd.get_data(ant1, ant2, pol) + print(dd.shape, ant1, ant2) + + # squeeze='full' collapses length-1 dimensions + data[i,:,:] = uvd.get_data(ant1, ant2, pol, squeeze='full').T + flags[i,:,:] = uvd.get_flags(ant1, ant2, pol, squeeze='full').T + return data, flags + + +def load_source_catalogue(fname, max_header_lines=20): + """ + Load a source catalogue in an expected standard text file format. The + file should be formatted as follows: + + - Line 0: Header (starting with #) with comma-separated list of field names + - Up to 20 optional header lines as key-value pairs separated by a comma, + e.g. `# ref_freq:300` + - Data as comma-separated values. + + The required fields are: + - `ra` and `dec`, equatorial coordinates in degrees + - `flux`, the flux at the reference frequency, in Jy + - `beta`, the spectral index of the power-law in frequency. + + Parameters: + fname (str): + Path to the catalogue file. This should be a comma-separated text + file with a header. + max_header_lines (int): + Maximum number of header lines to check for at the start of the + file. This only needs to be changed if you have more header lines + than the default maximum. If you have fewer header lines than + this, you don't need to change it. + + Returns: + cat (dict): + Dictionary containing arrays of values for each named field. + meta (dict): + Dictionary of metadata key:value pairs. + """ + # Define required fields + required = ['ra', 'dec', 'flux', 'beta'] + + # Get the header + with open(fname, 'r') as f: + # Read first line and remove whitespace and leading/trailing characters + header = f.readline() + header = header.replace("#", "").replace("\n", "").replace(" ", "") + fields = header.lower().split(",") + + # Get column number of each field + field_map = {field: j for field, j in enumerate(fields)} + + # Check for metadata in subsequent lines + metadata = {} + for i in range(max_header_lines): + # This will return a blank line if the end of the file is reached, + # so no need to test + line = f.readline() + if "#" and ":" in line: + line = line.replace("#", "").replace("\n", "").replace(" ", "") + vals = line.lower().split(":") + key, val = vals[0], vals[1] + metadata[key] = value + + # Check that required fields are present + for req in required: + if req not in field_map.keys(): + raise KeyError("Field '%s' was not found in catalogue file. " + "The following fields were found: %s" + % (req, str(field_map.keys()))) + + # Load data + d = np.loadtxt(fname, comments='#', delimiter=',') + assert d.shape[0] == len(field_map.keys()), \ + "Number of columns in data is different from header" + + # Re-pack catalogue into dict + cat = {} + for field in field_map: + cat[field] = d[field_map[field]] + + return cat, metadata + \ No newline at end of file diff --git a/hydra/region_sampler.py b/hydra/region_sampler.py index 98eb00b..4df0699 100644 --- a/hydra/region_sampler.py +++ b/hydra/region_sampler.py @@ -43,17 +43,17 @@ def get_diffuse_sky_model_pixels(freqs, nside=32, sky_model="gsm2016"): "lfsm", ], "Available sky models: 'gsm2008', 'gsm2016', 'haslam', 'lfsm'" if sky_model == "gsm2008": - model = pygdsm.GlobalSkyModel(freq_unit="MHz", include_cmb=False) + model = pygdsm.GlobalSkyModel(freq_unit="MHz") if sky_model == "gsm2016": try: - model = pygdsm.GlobalSkyModel2016(freq_unit="MHz", include_cmb=False) + model = pygdsm.GlobalSkyModel2016(freq_unit="MHz") except(AttributeError): # Different versions of pygdsm changed the API - model = pygdsm.GlobalSkyModel16(freq_unit="MHz", include_cmb=False) + model = pygdsm.GlobalSkyModel16(freq_unit="MHz") if sky_model == "haslam": - model = pygdsm.HaslamSkyModel(freq_unit="MHz", include_cmb=False) + model = pygdsm.HaslamSkyModel(freq_unit="MHz") if sky_model == "lfsm": - model = pygdsm.LowFrequencySkyModel(freq_unit="MHz", include_cmb=False) + model = pygdsm.LowFrequencySkyModel(freq_unit="MHz") model.generate(freqs_MHz) sky_maps = model.generated_map_data # (Nfreqs, Npix), should be in Kelvin diff --git a/hydra/tests/test_io.py b/hydra/tests/test_io.py new file mode 100644 index 0000000..2d2e121 --- /dev/null +++ b/hydra/tests/test_io.py @@ -0,0 +1,32 @@ + +import unittest + +import numpy as np +from hydra import io + +class TestIO(unittest.TestCase): + + def test_load_uvdata_metadata(self): + + import pyuvdata + import os + + # Get path to pyuvdata's built-in test data + fname = os.path.join(pyuvdata.data.DATA_PATH, "zen.2458432.34569.uvh5") + + # Check that the function runs + data_info = io.load_uvdata_metadata(comm=None, fname=fname) + + # Check that returned dict has the expected fields + expected_fields = ['freqs', 'lsts', 'lat', 'lon', 'alt', 'antpairs', + 'ants1', 'ants2', 'ants'] + for f in expected_fields: + self.assertTrue(f in data_info.keys()) + self.assertTrue(len(data_info) == len(expected_fields)) + + # Check that the data can be loaded + data, flags = io.partial_load_uvdata(fname, + freq_chunk=data_info['freqs'], + lst_chunk=data_info['lsts'], + antpairs=data_info['antpairs']) + \ No newline at end of file diff --git a/hydra/tests/test_vis_simulator.py b/hydra/tests/test_vis_simulator.py new file mode 100644 index 0000000..f2c5640 --- /dev/null +++ b/hydra/tests/test_vis_simulator.py @@ -0,0 +1,90 @@ + +import unittest + +import numpy as np +from hydra import vis_simulator, utils, example + +class TestVisSimulator(unittest.TestCase): + + def test_vis_sim_per_source(self): + + import pyuvsim + from matvis import conversions + + Nptsrc = 17 + + # Basic array layout + ant_pos = { + 0: (0., 0., 0.), + 1: (0., 14., 0.), + 2: (0., 0., 14.), + } + antpairs = [(0,1), (0,2), (1,2)] + lsts = np.linspace(0., 3., 5) # LSTs + + # Make dish diameter small so not all sources are far from mainlobe + beams = [pyuvsim.analyticbeam.AnalyticBeam('gaussian', diameter=6.) + for ant in ant_pos.keys()] + beams = [ + conversions.prepare_beam(beam, polarized=False, use_feed='x') + for beam in beams + ] + + # Settings + freqs = np.linspace(100., 200., 10) # MHz + + # RA goes from [0, 2 pi] and Dec from [-pi / 2, +pi / 2]. + ra, dec, amps = example.generate_random_ptsrc_catalogue( + Nptsrc=Nptsrc, + ra_bounds=(0., 2.*np.pi), + dec_bounds=(-0.5*np.pi, 0.5*np.pi), + logflux_bounds=(-1.0, 2.0) + ) + + # Get fluxes from ptsrc amplitude at ref. frequency + beta_ptsrc = -2.7 + fluxes = utils.get_flux_from_ptsrc_amp(amps, freqs, beta_ptsrc) + + # Source coordinate transform, from equatorial to Cartesian + crd_eq = conversions.point_source_crd_eq(ra, dec) + + # Get coordinate transforms as a function of LST + latitude = np.deg2rad(-30.7215) + eq2tops = np.array([conversions.eci_to_enu_matrix(lst, latitude) for lst in lsts]) + antpos = np.array([ant_pos[k] for k in ant_pos.keys()]) + + for j, freq in enumerate(freqs): + + # Run vis_sim_per_source in mode that includes sqrt(fluxes) in calculation + vis1 = vis_simulator.vis_sim_per_source( + antpos, + freq*1e6, + eq2tops, + crd_eq, + fluxes[:,j], + beam_list=beams, + precision=2, + polarized=False, + force_no_beam_sqrt=True, + apply_fluxes_afterwards=False, + ) + self.assertTrue(np.all(~np.isnan(vis1))) + + # Run vis_sim_per_source in mode that applies fluxes after calculating + # the visibility response + vis1a = vis_simulator.vis_sim_per_source( + antpos, + freq*1e6, + eq2tops, + crd_eq, + fluxes[:,j], + beam_list=beams, + precision=2, + polarized=False, + force_no_beam_sqrt=True, + apply_fluxes_afterwards=True, + ) + self.assertTrue(np.all(~np.isnan(vis1a))) + + # Both methods should match + self.assertTrue(np.allclose(vis1, vis1a)) \ No newline at end of file diff --git a/hydra/utils.py b/hydra/utils.py index 982450f..f810ea6 100644 --- a/hydra/utils.py +++ b/hydra/utils.py @@ -2,19 +2,15 @@ from matvis import conversions import pyuvdata from pyuvsim import AnalyticBeam +from scipy.interpolate import RegularGridInterpolator -""" -# Terminal colour codes -#terminal_ = '\033[95m' -terminal_blue = '\033[94m' -terminal_cyan = '\033[96m' -terminal_green = '\033[92m' -terminal_yellow = '\033[93m' -#FAIL = '\033[91m' -terminal_endc = '\033[0m' -terminal_bold = '\033[1m' -terminal_ul = '\033[4m' -""" +import pygdsm +from astropy.coordinates import SkyCoord, AltAz, EarthLocation, Galactic, ICRS +from astropy.time import Time +import astropy.units as u +import healpy as hp + +C = 299792.458 def flatten_vector(v, reduced_idxs=None): @@ -705,6 +701,99 @@ def partial_fourier_basis_2d_from_nmax( return basis_fns, kfreq, ktime +def tsky_from_galaxy_model(comm, latitude=-30.7214, longitude=21.4280, height=77., + times_per_hour=12, freq_range=(50., 2), Nfreqs=10, + dish_diameter=14.): + """ + Predict the sky temperature from a global sky model as a function of + frequency and beam FWHM. + + Parameters: + xx + + Returns: + xx + """ + myid = 0 + if comm is not None: + myid = comm.Get_rank() + + sky_frame = Galactic() + + # Generate sky map + _sky_map = pygdsm.GlobalSkyModel16() + + # Set up HERA observatory location + location = EarthLocation(lat=latitude * u.deg, + lon=longitude * u.deg, + height=height * u.m) + obstime0 = '2026-01-01T18:00:00.00000' # need to choose a reference time + times = Time(obstime0, format='isot', scale='utc') \ + + np.linspace(0., 24., 24*times_per_hour) * u.hour + + # Get LSTs + lsts = np.array([t.sidereal_time('apparent', 'greenwich').rad for t in times]) + idxs_lsts_sorted = np.argsort(lsts) + + # Get pixel index of zenith pointing at each time + pixel_idxs = [] + for i, t in enumerate(times): + c = AltAz(az=0.*u.deg, alt=90.*u.deg, obstime=t, location=location).transform_to(sky_frame) + idx = hp.ang2pix(hp.npix2nside(map64.size), theta=c.b.rad + 0.5*np.pi, phi=c.l.rad) + pixel_idxs.append(idx) + + # Convert alt/az of zenith to sky coords for each frequency + freqs = np.linspace(freq_range[0], freq_range[1], Nfreqs) + temperature_vals = np.zeros((freqs.size, lsts.size)) + for j, freq in enumerate(freqs): + + # Calculate approx. beam FWHM in radians + fwhm = (C / (freq*1e6)) / dish_diameter # radians + + # Get sky map at this frequency + sky_map = _sky_map.generate(freq) + + # Get value of pixel at zenith at each time + pvals = np.zeros(len(times)) + for i, t in enumerate(times): + # Smooth the sky map at this freq. with the approximate FWHM + pvals[i] = hp.smoothing(sky_map, fwhm=fwhm)[pixel_idxs[i]] + pvals = np.array(pvals) + temperature_vals[j,:] = pvals[idxs_lsts_sorted] + + # Build 2D interpolator and return + temp_interp = RegularGridInterpolator((freqs, lsts[idxs_lsts_sorted]), + temperature_vals, + method='linear', + bounds_error=True) + return temp_interp + + +def noise_from_autos(ants, auto_vis, data_shape): + """ + + Parameters: + ants (list): + List of antenna indices. + auto_vis (list of array_like): + List of auto-baseline visibilities. + data_shape (tuple): + Shape of the chunk of the data array handled by this worker. + + Returns: + Ninv (array_like): + Inverse noise variance array. + noise_chunk (array_like): + Array containing a thermal noise realisation with the same + shape as the chunk of data handled by this worker. + """ + raise NotImplementedError("Not yet implemented.") + noise_chunk = sigma_noise * np.sqrt(0.5) \ + * ( 1.0 * np.random.randn(*data_chunk.shape) \ + + 1.j * np.random.randn(*data_chunk.shape)) + return noise_chunk + + def status(myid, message, colour=None): """ Print a status message. diff --git a/hydra/vis_simulator.py b/hydra/vis_simulator.py index 6dae639..e279fc4 100644 --- a/hydra/vis_simulator.py +++ b/hydra/vis_simulator.py @@ -112,6 +112,7 @@ def vis_sim_per_source( beam_idx: Optional[np.ndarray] = None, subarr_ant=None, force_no_beam_sqrt=False, + apply_fluxes_afterwards=True, ): """ Calculate visibility from an input intensity map and beam model. This is @@ -165,6 +166,11 @@ def vis_sim_per_source( with a particular antenna. force_no_beam_sqrt (bool): Do not take the square root of a beam even if it's a power beam. + apply_fluxes_afterwards (bool): + If True, use a unit flux for each source in each frequency channel, + and then apply the true fluxes afterwards. This is useful for sky + models with negative values. If False, the 'standard' `matvis` + method of taking the sqrt of the fluxes is used. Returns: vis (array_like): @@ -206,7 +212,15 @@ def vis_sim_per_source( # Intensity distribution (sqrt) and antenna positions. Does not support # negative sky. Factor of 0.5 accounts for splitting Stokes I between # polarization channels - Isqrt = np.sqrt(0.5 * I_sky).astype(real_dtype) + if np.any(I_sky < 0.) and not apply_fluxes_afterwards: + raise ValueError("Sky model has negative values; use the " + "`apply_fluxes_afterwards` setting to circumvent this.") + if apply_fluxes_afterwards: + # Unit flux in each channel, for each source (include factor of sqrt(0.5)) + Isqrt = np.sqrt(0.5).astype(real_dtype) + else: + Isqrt = np.sqrt(0.5 * I_sky).astype(real_dtype) + antpos = antpos.astype(real_dtype) ang_freq = 2.0 * np.pi * freq @@ -240,6 +254,7 @@ def vis_sim_per_source( # Primary beam pattern using direct interpolation of UVBeam object az, za = conversions.enu_to_az_za(enu_e=tx, enu_n=ty, orientation="uvbeam") + for i, bm in enumerate(beam_list): spw_axis_present = utils.get_beam_interp_shape(bm) kw = ( @@ -289,11 +304,11 @@ def vis_sim_per_source( # Complex voltages. # v *= Isqrt[above_horizon] - v *= Isqrt[:] + v *= Isqrt v[:, ~above_horizon] *= 0.0 # zero-out sources below the horizon # Compute visibilities using product of complex voltages (upper triangle). - # Input arrays have shape (Nax, Nfeed, [Nants], Nsrcs + # Input arrays have shape (Nax, Nfeed, [Nants], Nsrcs) v = A_s[:, :, beam_idx] * v[np.newaxis, np.newaxis, :] # If a subarray is requested, only compute the visibilities that involve @@ -317,6 +332,11 @@ def vis_sim_per_source( optimize=True, ) + # If requested, apply fluxes once the per-source visibility response has been calculated + if apply_fluxes_afterwards: + # vis has shape (NAXES, NFEED, NTIMES, NANTS, NANTS, NSRCS) + vis *= I_sky[np.newaxis,np.newaxis,np.newaxis,np.newaxis,np.newaxis,:] + # Return visibilities with or without multiple polarization channels return vis if polarized else vis[0, 0]