Source code for caltrig.core.backend

# This is where all the ultility functions are stored

# The first function is to load in the data
import dask.array as darr
import numpy as np
import pandas as pd
import xarray as xr
import zarr as zr
from os.path import isdir, isfile
from os.path import join as pjoin
from os import listdir
from typing import Callable, List, Optional, Union, Dict
import os
from pathlib import Path
from uuid import uuid4
import rechunker
import dask as da
from dask.delayed import optimize as default_delay_optimize
import json
import shutil
from dask.diagnostics import ProgressBar

from .caiman_utils import detrend_df_f, minian_to_caiman

from scipy.signal import welch, savgol_filter
from scipy.cluster.hierarchy import dendrogram, linkage, fcluster
from scipy.ndimage.measurements import center_of_mass
from skimage.measure import find_contours

from matplotlib import cm
import matplotlib.pyplot as plt

import configparser
import time
import datetime

[docs]def open_minian( dpath: str, post_process: Optional[Callable] = None, return_dict=True ) -> Union[dict, xr.Dataset]: """ Taken from https://github.com/denisecailab/minian/blob/f64c456ca027200e19cf40a80f0596106918fd09/minian/utilities.py#L278. The current version of minian has outdated dependencies and is not compatible with this project. Load an existing minian dataset. If `dpath` is a file, then it is assumed that the full dataset is saved as a single file, and this function will directly call :func:`xarray.open_dataset` on `dpath`. Otherwise if `dpath` is a directory, then it is assumed that the dataset is saved as a directory of `zarr` arrays, as produced by :func:`save_minian`. This function will then iterate through all the directories under input `dpath` and load them as `xr.DataArray` with `zarr` backend, so it is important that the user make sure every directory under `dpath` can be load this way. The loaded arrays will be combined as either a `xr.Dataset` or a `dict`. Optionally a user-supplied custom function can be used to post process the resulting `xr.Dataset`. Parameters ---------- dpath : str The path to the minian dataset that should be loaded. post_process : Callable, optional User-supplied function to post process the dataset. Only used if `return_dict` is `False`. Two arguments will be passed to the function: the resulting dataset `ds` and the data path `dpath`. In other words the function should have signature `f(ds: xr.Dataset, dpath: str) -> xr.Dataset`. By default `None`. return_dict : bool, optional Whether to combine the DataArray as dictionary, where the `.name` attribute will be used as key. Otherwise the DataArray will be combined using `xr.merge(..., compat="no_conflicts")`, which will implicitly align the DataArray over all dimensions, so it is important to make sure the coordinates are compatible and will not result in creation of large NaN-padded results. Only used if `dpath` is a directory, otherwise a `xr.Dataset` is always returned. By default `False`. Returns ------- ds : Union[dict, xr.Dataset] The resulting dataset. If `return_dict` is `True` it will be a `dict`, otherwise a `xr.Dataset`. See Also ------- xarray.open_zarr : for how each directory will be loaded as `xr.DataArray` xarray.merge : for how the `xr.DataArray` will be merged as `xr.Dataset` """ if isfile(dpath): ds = xr.open_dataset(dpath).chunk() elif isdir(dpath): dslist = [] for d in listdir(dpath): arr_path = pjoin(dpath, d) if isdir(arr_path): if not arr_path.endswith("backup"): arr = list(xr.open_zarr(arr_path, consolidated=False).values())[0] arr.data = darr.from_zarr( os.path.join(arr_path, arr.name), inline_array=True ) dslist.append(arr) if return_dict: ds = {d.name: d for d in dslist} else: ds = xr.merge(dslist, compat="override") if (not return_dict) and post_process: ds = post_process(ds, dpath) return ds
[docs]def save_minian( var: xr.DataArray, dpath: str, meta_dict: Optional[dict] = None, overwrite=False, chunks: Optional[dict] = None, compute=True, mem_limit="500MB", ) -> xr.DataArray: """ Taken from https://github.com/denisecailab/minian/blob/f64c456ca027200e19cf40a80f0596106918fd09/minian/utilities.py#L440. The current version of minian has outdated dependencies and is not compatible with this project, hence the function has been copied here. Save a `xr.DataArray` with `zarr` storage backend following minian conventions. This function will store arbitrary `xr.DataArray` into `dpath` with `zarr` backend. A separate folder will be created under `dpath`, with folder name `var.name + ".zarr"`. Optionally metadata can be retrieved from directory hierarchy and added as coordinates of the `xr.DataArray`. In addition, an on-disk rechunking of the result can be performed using :func:`rechunker.rechunk` if `chunks` are given. Parameters ---------- var : xr.DataArray The array to be saved. dpath : str The path to the minian dataset directory. meta_dict : dict, optional How metadata should be retrieved from directory hierarchy. The keys should be negative integers representing directory level relative to `dpath` (so `-1` means the immediate parent directory of `dpath`), and values should be the name of dimensions represented by the corresponding level of directory. The actual coordinate value of the dimensions will be the directory name of corresponding level. By default `None`. overwrite : bool, optional Whether to overwrite the result on disk. By default `False`. chunks : dict, optional A dictionary specifying the desired chunk size. The chunk size should be specified using :doc:`dask:array-chunks` convention, except the "auto" specifiication is not supported. The rechunking operation will be carried out with on-disk algorithms using :func:`rechunker.rechunk`. By default `None`. compute : bool, optional Whether to compute `var` and save it immediately. By default `True`. mem_limit : str, optional The memory limit for the on-disk rechunking algorithm, passed to :func:`rechunker.rechunk`. Only used if `chunks` is not `None`. By default `"500MB"`. Returns ------- var : xr.DataArray The array representation of saving result. If `compute` is `True`, then the returned array will only contain delayed task of loading the on-disk `zarr` arrays. Otherwise all computation leading to the input `var` will be preserved in the result. Examples ------- The following will save the variable `var` to directory `/spatial_memory/alpha/learning1/minian/important_array.zarr`, with the additional coordinates: `{"session": "learning1", "animal": "alpha", "experiment": "spatial_memory"}`. >>> save_minian( ... var.rename("important_array"), ... "/spatial_memory/alpha/learning1/minian", ... {-1: "session", -2: "animal", -3: "experiment"}, ... ) # doctest: +SKIP """ dpath = os.path.normpath(dpath) Path(dpath).mkdir(parents=True, exist_ok=True) ds = var.to_dataset() if meta_dict is not None: pathlist = os.path.split(os.path.abspath(dpath))[0].split(os.sep) ds = ds.assign_coords( **dict([(dn, pathlist[di]) for dn, di in meta_dict.items()]) ) md = {True: "a", False: "w-"}[overwrite] fp = os.path.join(dpath, var.name + ".zarr") if overwrite: try: shutil.rmtree(fp) except FileNotFoundError: pass arr = ds.to_zarr(fp, compute=compute, mode=md) if (chunks is not None) and compute: chunks = {d: var.sizes[d] if v <= 0 else v for d, v in chunks.items()} dst_path = os.path.join(dpath, str(uuid4())) temp_path = os.path.join(dpath, str(uuid4())) with da.config.set( array_optimize=darr.optimization.optimize, delayed_optimize=default_delay_optimize, ): zstore = zr.open(fp) rechk = rechunker.rechunk( zstore[var.name], chunks, mem_limit, dst_path, temp_store=temp_path ) rechk.execute() try: shutil.rmtree(temp_path) except FileNotFoundError: pass arr_path = os.path.join(fp, var.name) for f in os.listdir(arr_path): os.remove(os.path.join(arr_path, f)) for f in os.listdir(dst_path): os.rename(os.path.join(dst_path, f), os.path.join(arr_path, f)) os.rmdir(dst_path) if compute: arr = xr.open_zarr(fp)[var.name] arr.data = darr.from_zarr(os.path.join(fp, var.name), inline_array=True) return arr
[docs]def overwrite_xarray( varr: xr.DataArray, dpath: str, retrieve: bool = False, ) -> xr.DataArray: """ Save an xarray DataArray to a zarr file. This function creates a temporary zarr file in the same directory as the existing zarr file, and then renames the temporary file to the original. This is due to the fact that certain errors would occur whenever I would try to save the zarr file directly to the original file, loading the zarr array into memory would also cause the same error. This is a workaround to avoid the error. Parameters ---------- varr : xr.DataArray The xarray DataArray that should be saved. dpath : str The path to the zarr file that should be saved. retrieve : bool, optional Whether the saved xarray DataArray should be read from the zarr file. By default `False`. Returns ------- arr : xr.DataArray The saved xarray DataArray. It will identical to the input `varr` but it will read from a new zarr file. """ dpath = os.path.normpath(dpath) fp_temp = os.path.join(dpath, varr.name + "_temp.zarr") fp_orig = os.path.join(dpath, varr.name + ".zarr") arr = varr.to_zarr(fp_temp, compute=True, mode="w", consolidated=False) try: shutil.rmtree(fp_orig) except FileNotFoundError: pass # Rename the temp file to the original file _safe_rename(fp_temp, fp_orig) if retrieve: arr = xr.open_zarr(fp_orig)[varr.name] arr.data = darr.from_zarr(os.path.join(fp_orig, varr.name), inline_array=True) return arr
def _safe_rename(fp_temp, fp_orig, retries=5, delay=0.1): # Occasionally on Windows due to permissions errors, the file rename will fail. # This function will attempt to rename the file multiple times with a delay to avoid this. for _ in range(retries): try: os.rename(fp_temp, fp_orig) return except PermissionError: time.sleep(delay) raise PermissionError(f"Failed to rename {fp_temp} to {fp_orig} after {retries} attempts")
[docs]def delete_xarray( dpath: str, var_name: str = "M") -> None: """ Delete the specified xarray DataArray by removing the zarr file. The function serves as a convenience method to deal with "Missing" DataArrays. It will be necessary to call this whenever all missing cells have been removed. Parameters ---------- dpath : str The path to the zarr file that should be deleted. var_name : str, optional The name of the DataArray that should be deleted. By default "M", as we expect this to be the missing data array. """ fp = os.path.join(dpath, var_name + ".zarr") try: shutil.rmtree(fp) except FileNotFoundError: pass
[docs]class Event: ''' An event in this context refers to external behavioral events, such as RNFs, ALPs, ILPs, etc... This class also contains various methods to extract relevant information for each Event. Parameters ---------- event_type : str The type of behavioral event, e.g. "ALP", "ILP", "RNF", etc... data : xr.DataArray The data array that contains the all CNMF output related data. timesteps : List[int] A list of timesteps where the event occurs. ''' def __init__( self, event_type:str, data:xr.DataArray, timesteps:List[int] ): self.has_param = False self.data = data self.event_type = event_type self.delay: float self.duration: float self.switch = False self.timesteps = timesteps self.values:dict self.binSize:int self.preBinNum: int self.postBinNum: int self.binList: list def set_binSize(self, binSize: int): self.binSize = binSize def set_preBinNum(self, preBinNum: int): self.preBinNum = preBinNum def set_postBinNum(self, postBinNum: int): self.postBinNum = postBinNum def get_binList(self, event_frame: int, preBinNum: int, postBinNum: int, binSize: int,value_type: int): binList = [] for i in range(-preBinNum,postBinNum): bin = self.get_interval_section(event_frame, binSize, i*binSize, 0, value_type)[0] binList.append(bin) return binList def set_delay_and_duration(self, delay:float, duration:float): self.delay = delay self.duration = duration self.has_param = True def set_switch(self, switch : bool = True): self.switch = switch
[docs] def get_section(self, event_frame: int, duration: float, delay: float = 0.0, type: str = "C") -> xr.Dataset: """ Return the selection of the data that is within the given time frame. duration indicates the number of frames. Parameters ---------- event_frame: int event time stamp duration : float last time (seconds) delay: float before or after (seconds) """ # duration is in seconds convert to ms duration *= 1000 delay *= 1000 max_length = len(self.data['Time Stamp (ms)']) if delay > 0: frame_gap = 1 while self.data['Time Stamp (ms)'][event_frame + frame_gap] - self.data['Time Stamp (ms)'][event_frame] < delay: frame_gap += 1 event_frame += frame_gap elif delay < 0: frame_gap = -1 while self.data['Time Stamp (ms)'][event_frame + frame_gap] - self.data['Time Stamp (ms)'][event_frame] > delay and event_frame + frame_gap > 0: frame_gap -= 1 event_frame += frame_gap frame_gap = 1 while self.data['Time Stamp (ms)'][event_frame + frame_gap] - self.data['Time Stamp (ms)'][event_frame] < duration and event_frame + frame_gap < max_length-1: frame_gap += 1 if type in self.data: return self.data[type].sel(frame=slice(event_frame, event_frame+frame_gap)) , event_frame,event_frame+frame_gap else: print("No %s data found in minian file" % (type)) return None
[docs] def get_interval_section(self, event_frame: int, duration: float, delay: float = 0.0, interval:int = 100, type: str = "C") -> xr.Dataset: ''' Return the selection of the data that is within the given time frame. Parameters ---------- event_frame: int Frame at which the event occurs duration : float Duration of the event in seconds delay: float Specifies how much time from the event frame should be included in the selection. If delay is positive, then the selection will start from the event frame + delay. If delay is negative, then the selection will start from the event frame - delay. interval: int The interval at which the data should be sampled. This is in milliseconds. type: str Specfies which data type to extract from the minian file. Default is "C". ''' # duration is in seconds convert to ms integrity = True duration *= 1000 delay *= 1000 frame_list = [] max_length = len(self.data['Time Stamp (ms)']) if delay > 0: frame_gap = 0 while self.data['Time Stamp (ms)'][event_frame + frame_gap] - self.data['Time Stamp (ms)'][event_frame] < delay: if (event_frame + frame_gap) < (max_length-1): frame_gap += 1 else: integrity = False break event_frame += frame_gap elif delay < 0: frame_gap = 0 while self.data['Time Stamp (ms)'][event_frame + frame_gap] - self.data['Time Stamp (ms)'][event_frame] > delay: if(event_frame + frame_gap > 0): frame_gap -= 1 else: integrity = False break event_frame += frame_gap frame_gap = 0 time_flag = self.data['Time Stamp (ms)'][event_frame] frame_list.append(event_frame) while self.data['Time Stamp (ms)'][event_frame + frame_gap] - self.data['Time Stamp (ms)'][event_frame] < duration and event_frame + frame_gap < max_length-1: if self.data['Time Stamp (ms)'][event_frame + frame_gap]-time_flag > interval: time_flag = self.data['Time Stamp (ms)'][event_frame + frame_gap] frame_list.append(event_frame + frame_gap) frame_gap += 1 if type in self.data: return self.data[type].sel(frame = frame_list) , event_frame, event_frame + frame_gap, integrity else: print("No %s data found in minian file" % (type)) return None
[docs] def set_values(self): """ Update the values dictionary with the values of the event data and the corresponding windows. """ values={} event_list= [] windows = [] if self.switch == False: self.values=values return else: for i in self.timesteps: single_event, start_frame, end_frame = self.get_section(i,self.duration,self.delay) event_list.append(single_event) windows.append([start_frame, end_frame]) for i in self.data['unit_ids']: values[i] = np.array([]) for i in event_list: for j in i.coords['unit_id'].values: values[j] = np.r_['-1', values[j], np.array(i.sel(unit_id=j).values)] self.values = values self.windows = windows
[docs]class DataInstance: ''' This class is used to store all the data related to a single experiment/recording. This includes all CNMF output data, behavioral/timestamp data and video data. Parameters ---------- config_path : str The path to the configuration file that contains the paths to the minian, behavior and video files. Attributes ---------- events_type : List[str] A list of all the event types that are supported by the program. mouseID : str The mouse ID for the experiment. day : str The day of the experiment. session : str The session of the experiment. group : str The group of the experiment. data : dict A dictionary that contains all the CNMF output data. The keys are 'A', 'C', 'S', 'E', 'b', 'f', 'DFF', 'YrA', 'M', 'timestamp(ms)'. video_data : dict A dictionary that contains all the video data. The keys are 'Y_fm_chk', 'varr', 'Y_hw_chk', 'behavior_video'. events : dict A dictionary that contains all the behavior event data. The keys are the event types and the values are Event objects. ''' distance_metric_list = ['euclidean','cosine'] # Static variable so parameters can be read before initiating instance def __init__( self, config_path: str ): self.events_type = ['ALP','ILP','RNF','ALP_Timeout'] self.config_path = config_path self.mouseID : str self.day : str self.session: str self.group: str self.data: dict # Original data, key:'A', 'C', 'S','unit_ids' self.video_data: dict # Video data, key: 'Y_fm_chk', 'varr', 'Y_hw_chk', 'behavior_video' self.events: dict # {"ALP": Event, "ILP" : Event, "RNF": Event} self.outliers_list: List[int] = [] self.centroids: dict self.centroids_to_cell_ids: dict self.centroids_max: dict self.load_data(config_path=config_path) self.no_of_clusters = 4 self.distance_metric = 'euclidean' self.missed_signals = {} self.load_events(self.events_type) self.changed_events = False # This is necessary to in the case of recalculating values for sda_widgets self.noise_values = {} self.cell_ids_to_groups = {} # This differs from self.group as it is used for preselected groups by the user self.load_cell_groups() # Load cell groups from JSON if it exists # Create the default image self.clustering_result = {"basic": {"image": np.stack((self.data['A'].sum("unit_id").values,)*3, axis=-1)}}
[docs] def add_missed(self, A: np.array): """ Adds a missed cell to the data. The missed cell is represented by a footprint mask and added to the M data array and saved to the data folder. A unique missed_id is assigned to the missed cell. Parameters ---------- A : np.array The footprint mask of the missed cell. """ id = max(self.data["M"].coords["missed_id"].values) + 1 if self.data["M"] is not None else 1 M = xr.DataArray(np.expand_dims(A, axis=0), dims=["missed_id", "height", "width"], coords={"missed_id": [id], "height": self.data['A'].coords["height"].values, "width": self.data['A'].coords["width"].values}, name="M") if self.data["M"] is not None: M_old = self.data["M"].load() M = xr.concat([M_old, M], dim="missed_id") self.data["M"] = overwrite_xarray(M, self.cnmf_path, retrieve=True) return id
[docs] def remove_missed(self, ids: List[int]): """ Removes the missed cells from the data. The missed cells are identified by their missed_id. Parameters ---------- ids : List[int] A list of missed_ids that should be removed. """ M = self.data["M"].load() M = M.drop_sel(missed_id=ids) if M.size == 0: delete_xarray(self.cnmf_path, "M") self.data["M"] = None else: self.data["M"] = overwrite_xarray(M, self.cnmf_path, retrieve=True)
def parse_file(self, config_path):# set up configure file config = configparser.ConfigParser() try: config.read(config_path) except: print("ERROR: ini file is either not in the correct format or empty, did you make sure to save the ini file?") if len(config.sections())==1 and config.sections()[0]=='Session_Info': mouseID = config['Session_Info']['mouseID'] day = config['Session_Info']['day'] session = config['Session_Info']['session'] group = config['Session_Info']['group'] data_path = config['Session_Info']['data_path'] behavior_path = config['Session_Info']['behavior_path'] video_path = config['Session_Info']['video_path'] return mouseID, day, session, group, data_path, behavior_path, video_path else: print("Error! Section name should be 'Session_Info'!") def contains(self, video_type, data_keys): for key in data_keys: if video_type in key: return True, key return False, None def load_videos(self): data = open_minian(self.video_path) video_types = ["Y_fm_chk", "varr", "Y_hw_chk", "behavior_video"] video_data = {} for video_type in video_types: exists, data_type = self.contains(video_type, list(data.keys())) if exists: video_data[video_type] = data[data_type] else: print("No %s data found in video folder" % (video_type)) if video_type == "Y_hw_chk" and "Y_fm_chk" in data: print("Creating Y_hw_chk from Y_fm_chk. This may take a while.") Y_hw_chk = save_minian( data["Y_fm_chk"].rename("Y_hw_chk"), self.video_path, overwrite=True, chunks={"frame": -1, "height": 32, "width": 32}, ) video_data[video_type] = Y_hw_chk print("Done creating Y_hw_chk") self.video_data = video_data
[docs] def load_data(self, config_path): """ Load the data from the data path specified in the config file. Parameters ---------- config_path : str The path to the configuration file that contains the paths to the minian, behavior and video files. """ mouseID, day, session, group, cnmf_path, behavior_path, video_path = self.parse_file(config_path) self.mouseID = mouseID self.day = day self.session = session self.group = group self.cnmf_path = cnmf_path self.video_path = video_path behavior_data = pd.read_csv(behavior_path,sep=',') data_types = ['RNF', 'ALP', 'ILP', 'ALP_Timeout','Time Stamp (ms)'] self.data = {} for dt in data_types: if dt in behavior_data: self.data[dt] = behavior_data[dt] else: print("No %s data found in minian file" % (dt)) self.data[dt] = None data = open_minian(cnmf_path) data_types = ['A', 'C', 'S', 'E', 'b', 'f', 'DFF', 'YrA', 'M','timestamp(ms)'] timestamp = behavior_data[["Time Stamp (ms)"]] timestamp.index.name = "frame" da_ts = timestamp["Time Stamp (ms)"].to_xarray() data['timestamp(ms)'] = da_ts for dt in data_types: if dt in data: self.data[dt] = data[dt] if "unit_id" in self.data[dt].coords: self.data[dt] = self.data[dt].dropna(dim="unit_id") # Safe guard against deprecated E standard and erroneous E data if dt == 'E': self.data[dt] = self.data[dt].fillna(0).where(self.data[dt] == 0, 1) else: print("No %s data found in minian file" % (dt)) self.data[dt] = None self.unit_id_consistency() self.data['unit_ids'] = np.sort(self.data['C'].coords['unit_id'].values) self.config_path = config_path self.data['filtered_C'] = self.get_filtered_C cells = self.data['unit_ids'] cent = self.centroid(self.data['A']) cent_max = self.centroid_max(self.data['A']) self.centroids = {} self.centroids_max = {} for i in cells: self.centroids[i] = tuple(cent.loc[cent['unit_id'] == i].values[0][1:]) self.centroids_max[i] = tuple(cent_max.loc[cent_max['unit_id'] == i].values[0][1:]) self.centroids_to_cell_ids = {v: k for k, v in self.centroids.items()} output_dpath = "/N/project/Cortical_Calcium_Image/analysis" if session is None: self.output_path = os.path.join(output_dpath, mouseID,day) else: self.output_path = os.path.join(output_dpath, mouseID,day,session) if(os.path.exists(self.output_path) == False): os.makedirs(self.output_path)
[docs] def unit_id_consistency(self): """ This function will check if the unit_ids are consistent across all data arrays. If not, it will drop the inconsistent unit_ids from all data arrays. This will be achieved by taking the intersection of all unit_ids and then filtering the data arrays. """ unit_id_list = [] keys = ['A', 'C', 'S', 'YrA'] for key in keys: unit_id_list.append(set(self.data[key].coords["unit_id"].values)) intersection = set.intersection(*unit_id_list) for key in keys: self.data[key] = self.data[key].sel(unit_id=list(intersection))
[docs] def get_filtered_C(self) -> None: """ This function will filter the C data array by multiplying it with the normalized S data array. This has the effect of removing non-event related signals from the C data array. """ normalized_S = xr.apply_ufunc( self.normalize_events, self.data['S'].chunk(dict(frame=-1, unit_id="auto")), input_core_dims=[["frame"]], output_core_dims=[["frame"]], dask="parallelized", output_dtypes=[self.data['E'].dtype], ) filtered_C = self.data['C'] * normalized_S return filtered_C
def normalize_events(self, a: np.ndarray) -> np.ndarray: # All positive values will be set to 1 a = a.copy() a[a > 0] = 1 return a def get_pdf_format(self, unit_ids, cluster, path): contours = [] for id in unit_ids: cell = self.data['A'].sel(unit_id=id).values yaoying_param = 6 thresholded_roi = 1 * cell > (np.mean(cell) + yaoying_param * np.std(cell)) contours.append(find_contours(thresholded_roi, 0)[0]) fig, ax = plt.subplots(figsize=(10, 10)) cluster = "all" if cluster == 0 else cluster ax.imshow(self.clustering_result[cluster]['image']) for cell, unit_id in zip(contours, unit_ids): ax.plot(cell[:, 1], cell[:, 0], color='xkcd:azure', alpha=0.5, linewidth=1) ax.text(np.mean(cell[:, 1]), np.mean(cell[:, 0]), unit_id, color='xkcd:azure', ha='center', va='center', fontsize="small") ax.axis('off') fig.savefig(path)
[docs] def get_timestep(self, type: str): """ Return a list that contains contains the a list of the frames where the ALP occurs. Parameters ---------- type : str The type of event to extract the timesteps from. """ return np.flatnonzero(self.data[type])
def save_justifications(self, justifications): filename = "justifications-{self.mouseID}-{self.day}-{self.session}.json" with open(os.path.join(self.output_path, filename), "w") as f: json.dump(justifications, f) def load_justifications(self): filename = "justifications-{self.mouseID}-{self.day}-{self.session}.json" if os.path.exists(os.path.join(self.output_path, filename)): with open(os.path.join(self.output_path, filename), "r") as f: return json.load(f) else: return {} def load_events(self, keys): events = {} for key in keys: events[key] = Event(key,self.data,self.get_timestep(key)) events[key].switch = True self.events = events def get_cell_sizes(self): return (self.data["A"] > 0).sum(["height", "width"]).compute() def get_missed_signal(self, missed_id: int): if missed_id in self.missed_signals: return self.missed_signals[missed_id] mask = self.data["M"].sel(missed_id=missed_id).compute() # Extract the dimensions of the mask from x and y axis x_range, y_range = self.get_mask_dimensions(mask) Y = self.video_data["Y_hw_chk"].sel(height=y_range, width=x_range).compute() mask_small = mask.sel(height=y_range, width=x_range) averaged_signal = (Y * mask_small).sum(["height", "width"]).compute() self.missed_signals[missed_id] = averaged_signal.values return self.missed_signals[missed_id] def get_mask_dimensions(self, mask): # Get mask dimensions but only for positive values mask_x = mask.any("height").values mask_y = mask.any("width").values # Get indices in range x_range = np.where(mask_x)[0] y_range = np.where(mask_y)[0] return x_range, y_range def get_total_transients(self, unit_id=None): # Diff won't capture the first transient, so we add 1 to the sum if # the first frame is 1 if unit_id is None: total_transients = (self.data["E"].diff(dim="frame") == 1).sum(dim="frame") + (self.data["E"].isel(frame=0)) return total_transients.compute() else: total_transients = (self.data["E"].sel(unit_id = unit_id).diff(dim="frame") == 1).sum(dim="frame") + (self.data["E"].sel(unit_id = unit_id).isel(frame=0)) return total_transients.compute()
[docs] def get_average_peak_dff(self): """ Calculate the average peak dff for each cell. The peak dff is calculated by taking the maximum value of the DFF signal of each transient. Then calculate the average of all the peak dffs. Returns ------- results : dict A dictionary where the keys are the unit_ids and the values are the average peak dffs. """ E = self.data["E"].compute() DFF = self.data["DFF"].compute() results = {} for unit_id in self.data["unit_ids"]: peaks = [] transients = E.sel(unit_id=unit_id).values.nonzero()[0] if transients.any(): # Split up the indices into groups transients = np.split(transients, np.where(np.diff(transients) != 1)[0]+1) # Now Split the indices into pairs of first and last indices transients = [(indices_group[0], indices_group[-1]+1) for indices_group in transients] for start, stop in transients: peaks.append(DFF.sel(unit_id=unit_id, frame=slice(start, stop)).max().values.item()) if peaks: results[unit_id] = np.mean(peaks) return results
def get_total_rising_frames(self): return (self.data["E"] != 0).sum(dim="frame").compute() def get_std(self): return self.data["DFF"].std(dim="frame").compute()
[docs] def get_mad(self, id=None): """ Get the median absolute deviation. Parameters ---------- id : int The unit_id of the cell for which the MAD should be calculated. If None, then the MAD will be calculated for all cells. """ if id is None: median = self.data["DFF"].median(dim="frame").compute() mad = abs(self.data["DFF"] - median).median(dim="frame").compute() else: # It throws an error when trying to extract the median from a single cell therefore convert to numpy first dff = self.data["DFF"].sel(unit_id=id).values median = np.nanmedian(dff) mad = np.nanmedian(abs(dff - median)) return (1 / 0.6745) * mad
[docs] def get_savgol(self, id, params={}): """ Calculate the Savitzky-Golay filter for the DFF signal, this will be used to estimate the noise. Parameters ---------- id : int The unit_id of the cell for which the Savitzky-Golay filter should be calculated. params : dict A dictionary that contains the parameters for the Savitzky-Golay filter. The parameters are: * win_len : int, optional The length of the filter window. Must be an odd integer. Default is 10. * poly_order : int, optional The polynomial order. Default is 2. * deriv : int, optional The order of the derivative to compute. Default is 0. * delta : float, optional The spacing of the samples to which the filter will be applied. Default is 1.0. * mode : str, optional The mode parameter for the savgol_filter function. Default is "interp". """ window_length = params.get("win_len", 10) poly_order = params.get("poly_order", 2) deriv = params.get("deriv", 0) delta = params.get("delta", 1.0) mode = params.get("mode", "interp") data = self.data["DFF"].sel(unit_id=id).values savgol_data = savgol_filter(data, window_length, poly_order, deriv=deriv, delta=delta, mode=mode) return savgol_data
[docs] def get_noise(self, savgol_data, id, params={}): """ Noise will be estimated by taking the absolute value difference between the dff data and savgol_smoothed signal. The noise will be then estimated with a rolling window approach where the mean, median or maximum value will be taken. Parameters ---------- savgol_data : np.array The Savitzky-Golay smoothed data. id : int The unit_id of the cell for which the noise should be calculated. """ noise_type = params.get("type", "Mean") win_len = params.get("win_len", 10) cap = params.get("cap", 0.01) if id in self.noise_values: param = noise_type + str(win_len) if param in self.noise_values[id]: return self.noise_values[id][param] else: self.noise_values[id] = {} dff = self.data["DFF"].sel(unit_id=id).values noise = abs(dff - savgol_data) if noise_type == "Mean": noise = np.convolve(noise, np.ones(win_len), 'same') / win_len elif noise_type == "Median": noise = self.rolling(noise, win_len, "median") elif noise_type == "Max": noise = self.rolling(noise, win_len, "max") noise[noise < cap] = cap # May be expensive to compute so save in noise_values if id not in self.noise_values: self.noise_values[id] = {} self.noise_values[id][noise_type + str(win_len)] = noise return noise
[docs] def get_SNR(self, savgol_data, noise): """ We will simply calculate the ratio. However, we will need to make sure that the noise is not 0. Any 0 value will be replaced with the lowest non-zero value. Parameters ---------- savgol_data : np.array The Savitzky-Golay smoothed data. noise : np.array The noise data. """ # First check if the noise is 0 if noise.sum() == 0: print("ERROR: Noise is 0") return noise # Return the noise as the SNR to indicate some sort of error snr = np.abs(savgol_data) / noise # Normalize it so that the SNR max is the savgol_data max return snr / snr.max() * savgol_data.max()
def rolling(self, data, window, rolling_type="median"): # Use the pd.Series approach to calculate the rolling window s = pd.Series(data) if rolling_type == "median": return s.rolling(window, center=True, min_periods=1).median().to_numpy() elif rolling_type == "max": return s.rolling(window, center=True, min_periods=1).max().to_numpy()
[docs] def get_transient_frames(self, unit_ids=None): ''' Get the inter-event interval. The approach is as follows: the diff of the E array will give us the rising edges. For E this means that the start of each transient will have a value of 1. We can extrapolate the inter-event by taking their corresponding frame numbers and performing another diff on them. ''' # This will contain 1s for the start of each transient if unit_ids is not None: rising_edges = self.data["E"].sel(unit_id=unit_ids).diff(dim="frame").compute() else: rising_edges = self.data["E"].diff(dim="frame").compute() # Each cell will have a 1 corresponding to the start of each transient and a # a nan value for other frames. transient_frames = rising_edges.where(rising_edges==1,drop=True) # At this stage it is the most we are able to prune the data. The rest of the # pruning will be done on a per cell basis. return transient_frames.compute()
[docs] def get_mean_iei_per_cell(self, transient_frames, cell_id, total_transients, frame_rate=None): ''' Calculate the mean inter-event interval for a single cell. The mean inter-event interval is calculated by taking the difference between the start of each transient. The mean is then calculated from the differences. Parameters ---------- start_of_transients: xr.DataArray The start of each transient for all cells taken as an output of get_mean_iei(). cell_id: int The cell for which we want to calculate the mean inter-event interval total_transients: xr.DataArray This contains the total number of transients for each cell. The number of transients should correspond to the number of 1s in the frames array. If the length of frames is 1 less than the number of transients, then we can assume that the first transient starting at frame 0 was missed by the diff operation in get_transient_frames(). ''' if cell_id not in transient_frames.coords["unit_id"]: return "N/A" frames = transient_frames.coords["frame"].where(transient_frames.sel(unit_id = cell_id) == 1, drop=True).values if total_transients.sel(unit_id = cell_id) == 1: return "N/A" if len(frames) == total_transients.sel(unit_id = cell_id).item()-1: frames = np.insert(frames, 0, 0) if frame_rate is None: return str(round(np.mean(np.diff(frames)))) else: return str(round(np.mean(np.diff(frames))/frame_rate, 3))
[docs] def get_transient_frames_iti_dict(self, unit_ids) -> tuple[dict, dict]: ''' Does the same thing as get_transient_frames() but returns two dictionaries. The first dictionary contains the unit_ids as keys and the values are the transient frames. The second dictionary contains the unit_ids as keys and the values are the inter-event intervals. Parameters ---------- unit_ids: List[int] The list of unit_ids for which the inter-event interval should be calculated. Returns ------- frame_start: dict A dictionary where the keys are the unit_ids and the values are the start of each transient. iti: dict A dictionary where the keys are the unit_ids and the values are the inter-event intervals. ''' transient_frames = self.get_transient_frames(unit_ids=unit_ids) itis = {} frame_start = {} for unit_id in unit_ids: frame_start[unit_id] = transient_frames.coords["frame"].where(transient_frames.sel(unit_id = unit_id) == 1, drop=True).values itis[unit_id] = np.diff(frame_start[unit_id]) # Calculate differences (IEIs) for each cell return frame_start, itis
def set_vector(self): values = {} for uid in self.data['unit_ids']: values[uid] = np.array([]) if 'ALP' in self.events.keys(): for key in self.events['ALP'].values: values[key] = np.r_['-1', values[key], self.events['ALP'].values[key]] if 'ILP' in self.events.keys(): for key in self.events['ILP'].values: values[key] = np.r_['-1', values[key], self.events['ILP'].values[key]] if 'RNF' in self.events.keys(): for key in self.events['RNF'].values: values[key] = np.r_['-1', values[key], self.events['RNF'].values[key]] if 'ALP_Timeout' in self.events.keys(): for key in self.events['ALP_Timeout'].values: values[key] = np.r_['-1', values[key], self.events['ALP_Timeout'].values[key]] # If no events in this period for uid in self.data['unit_ids']: if values[uid].size == 0: values[uid] = self.data['C'].sel(unit_id = int(uid)).values self.values = values def set_distance_metric(self, distance_metirc:str): self.distance_metric = distance_metirc def set_group(self, group_type: str): self.group = group_type def set_outliers(self, outliers: List[int]): self.outliers_list = outliers def set_no_of_clusters(self, number : int): self.no_of_clusters = number def compute_clustering(self): self.cellClustering = CellClustering(self.values, self.outliers_list, self.data["A"], distance_metric = self.distance_metric) self.linkage_data = self.cellClustering.linkage_data self.clustering_result = self.cellClustering.visualize_clusters(self.no_of_clusters) def get_vis_info(self): image = self.clustering_result["all"]["image"] if "all" in self.clustering_result else self.clustering_result["basic"]["image"] return self.mouseID, self.session, self.day, self.group, image def get_dendrogram(self, ax): self.cellClustering.visualize_dendrogram(color_threshold =self.linkage_data[(self.no_of_clusters-1),2] ,ax=ax)
[docs] def centroid(self, A: xr.DataArray, verbose=False) -> pd.DataFrame: """ Compute centroids of spatial footprint of each cell. Parameters ---------- A : xr.DataArray Input spatial footprints. verbose : bool, optional Whether to print message and progress bar. By default `False`. Returns ------- cents_df : pd.DataFrame Centroid of spatial footprints for each cell. Has columns "unit_id", "height", "width" and any other additional metadata dimension. """ def rel_cent(im): im_nan = np.isnan(im) if im_nan.all(): return np.array([np.nan, np.nan]) if im_nan.any(): im = np.nan_to_num(im) cent = np.array(center_of_mass(im)) return cent / im.shape gu_rel_cent = darr.gufunc( rel_cent, signature="(h,w)->(d)", output_dtypes=float, output_sizes=dict(d=2), vectorize=True, ) cents = xr.apply_ufunc( gu_rel_cent, A.chunk(dict(height=-1, width=-1)), input_core_dims=[["height", "width"]], output_core_dims=[["dim"]], dask="allowed", ).assign_coords(dim=["height", "width"]) if verbose: print("computing centroids") with ProgressBar(): cents = cents.compute() cents_df = ( cents.rename("cents") .to_series() .dropna() .unstack("dim") .rename_axis(None, axis="columns") .reset_index() ) h_rg = (A.coords["height"].min().values, A.coords["height"].max().values) w_rg = (A.coords["width"].min().values, A.coords["width"].max().values) cents_df["height"] = cents_df["height"] * (h_rg[1] - h_rg[0]) + h_rg[0] cents_df["width"] = cents_df["width"] * (w_rg[1] - w_rg[0]) + w_rg[0] return cents_df
[docs] def centroid_max(self, A: xr.DataArray, verbose=False) -> pd.DataFrame: """ Compute the centroid by taking the maximum value in the image. Nearly the same as centroid() however it is looks better in the 3D visualizations Parameters ---------- A : xr.DataArray Input spatial footprints. verbose : bool, optional Whether to print message and progress bar. By default `False`. Returns ------- cents_df : pd.DataFrame Centroid Max of spatial footprints for each cell. Has columns "unit_id", "height", "width" and any other additional metadata dimension. """ def max_cent(im): im_nan = np.isnan(im) if im_nan.all(): return np.array([np.nan, np.nan]) if im_nan.any(): im = np.nan_to_num(im) max_index_flat = np.argmax(im) max_index = np.unravel_index(max_index_flat, im.shape) return np.array(max_index) gu_max_cent = darr.gufunc( max_cent, signature="(h,w)->(d)", output_dtypes=int, output_sizes=dict(d=2), vectorize=True, ) cents = xr.apply_ufunc( gu_max_cent, A.chunk(dict(height=-1, width=-1)), input_core_dims=[["height", "width"]], output_core_dims=[["dim"]], dask="allowed", ).assign_coords(dim=["height", "width"]) if verbose: print("computing centroids") with ProgressBar(): cents = cents.compute() cents_df = ( cents.rename("cents") .to_series() .dropna() .unstack("dim") .rename_axis(None, axis="columns") .reset_index() ) return cents_df
[docs] def update_and_save_E(self, unit_id: int, spikes: Union[list, np.ndarray], update_type: str = "Accept Incoming Only"): """ Update the E array with the final peaks and save it to the minian file. Parameters ---------- unit_id : int The unit_id of the cell for which the E array should be updated. spikes : Union[list, np.ndarray] The final peaks that should be added to the E array. update_type : str, optional The type of update that should be performed. The options are: * Accept Incoming Only : Only accept the incoming spikes and ignore any overlapping spikes. * Accept Overlapping Only : Accept all spikes including overlapping spikes. * Accept All : Accept all spikes and set the E array to 1 for all the spikes. """ # First convert final peaks into a numpy array E = self.data['E'] dtype = E.dtype if isinstance(spikes, list): new_e = np.zeros(E.shape[1], dtype=dtype) for spike in spikes: new_e[spike[0]:spike[1]] = 1 else: new_e = spikes.astype(dtype) E.load() # Load into memory if update_type == "Accept Overlapping Only": new_e *= E.sel(unit_id=unit_id).values elif update_type == "Accept All": new_e += E.sel(unit_id=unit_id).values new_e[new_e > 0] = 1 E.loc[dict(unit_id=unit_id)] = new_e # Now save the E array to disk overwrite_xarray(E, self.cnmf_path) self.changed_events = True
def clear_E(self, unit_id): E = self.data['E'] E.load() E.loc[dict(unit_id=unit_id)] = 0 overwrite_xarray(E, self.cnmf_path) self.changed_events = True
[docs] def backup_data(self, name: str): """ Backup a specified data array to the backup folder. Parameters ---------- name : str The name of the data array to backup. """ data = self.data[name] data.load() # Save to backup folder but first check if it exists if not os.path.exists(os.path.join(self.cnmf_path, "backup")): os.makedirs(os.path.join(self.cnmf_path, "backup")) t = time.localtime() current_time = time.strftime("%m_%d_%H_%M_%S", t) overwrite_xarray(data, os.path.join(self.cnmf_path, "backup", f"{name}_" + current_time))
def remove_from_E(self, clear_selected_events_local: Dict[int, List[int]]): E = self.data['E'] E.load() for unit_id, x_values in clear_selected_events_local.items(): events = E.sel(unit_id=unit_id).values events[x_values] = 0 E.loc[dict(unit_id=unit_id)] = events overwrite_xarray(E, self.cnmf_path) self.changed_events = True def add_to_E(self, add_selected_events_local: Dict[int, List[int]]): E = self.data['E'] E.load() for unit_id, x_values in add_selected_events_local.items(): events = E.sel(unit_id=unit_id) events[x_values] = 1 E.loc[dict(unit_id=unit_id)] = events overwrite_xarray(E, self.cnmf_path) self.changed_events = True
[docs] def reject_cells(self, cells: List[int]): """ Set the good_cells array to 0 for the cells in the list. """ E = self.data['E'] E.load() E['good_cells'].loc[dict(unit_id=cells)] = 0 E['verified'].loc[dict(unit_id=cells)] = 0 overwrite_xarray(self.data['E'], self.cnmf_path)
def approve_cells(self, cells: List[int]): E = self.data['E'] E.load() E['good_cells'].loc[dict(unit_id=cells)] = 1 overwrite_xarray(self.data['E'], self.cnmf_path) def update_verified(self, cells: List[int], force_verified: bool = False): E = self.data['E'] E.load() for cell in cells: if force_verified: E['verified'].loc[dict(unit_id=cell)] = 1 else: E['verified'].loc[dict(unit_id=cell)] = (E['verified'].loc[dict(unit_id=cell)].values.item() + 1) % 2 overwrite_xarray(self.data['E'], self.cnmf_path) def prune_non_verified(self, cells: set): # Keep only verified cells in cells verified_unit_ids = self.get_verified_cells() is_list = False if type(cells) == list: cells = set(cells) is_list = True if is_list: return list(cells.intersection(verified_unit_ids)) else: return cells.intersection(verified_unit_ids)
[docs] def prune_rejected_cells(self, cells): """ Prune the cells that have been rejected from the list of cells. """ E = self.data['E'] E.load() return [cell for cell in cells if E.sel(unit_id=cell)['good_cells'].values.item() == 1]
def get_verified_cells(self): all_unit_ids = self.data['E'].unit_id.values verified_idxs = self.data['E'].verified.values.astype(int) verified_unit_ids = all_unit_ids[verified_idxs==1] return verified_unit_ids def get_good_cells(self): all_unit_ids = self.data['E'].unit_id.values good_idxs = self.data['E'].good_cells.values.astype(int) good_unit_ids = all_unit_ids[good_idxs==1] return good_unit_ids
[docs] def check_E(self): """ Check if the E xarray exists and if not create it. """ if self.data['E'] is None: print("Creating E array") E = xr.DataArray( np.zeros(self.data['C'].shape), dims=["unit_id", "frame"], coords=dict( unit_id=self.data['unit_ids'], frame=self.data['C'].coords["frame"], ), name="E" ) E = E.assign_coords(good_cells=("unit_id", np.ones(len(self.data['unit_ids']))), verified=("unit_id", np.zeros(len(self.data['unit_ids'])))) E.coords['timestamp(ms)'] = self.data['timestamp(ms)'] self.data['E'] = overwrite_xarray(E, self.cnmf_path, retrieve=True) # For backwards compatibility check if the verified values exist and if not create them elif "verified" not in self.data['E'].coords: self.data['E'] = self.data['E'].assign_coords(verified=("unit_id", np.zeros(len(self.data['unit_ids'])))) overwrite_xarray(self.data['E'], self.cnmf_path)
[docs] def check_DFF(self): """ Check if the DFF xarray exists and if not create it. """ if self.data['DFF'] is None: print("Creating DFF array. Sit tight this could take a while.") # Convert the data into caiman format A, b, C, f, YrA = minian_to_caiman(self.data['A'], self.data['b'], self.data['C'], self.data['f'], self.data['YrA']) dff_array = detrend_df_f(A, b, C, f, YrA, flag_auto=False) DFF = xr.DataArray( dff_array, dims=["unit_id", "frame"], coords=dict( unit_id=self.data['unit_ids'], frame=self.data['C'].coords["frame"], ), name="DFF" ).chunk(dict(frame=-1, unit_id="auto")) self.data['DFF'] = overwrite_xarray(DFF, self.cnmf_path, retrieve=True)
[docs] def check_essential_data(self): ''' Create a list of essential data that is required for the analysis. ''' essential_data = ["A", "C", "S", "b", "f", "YrA"] got_data = True for data_type in essential_data: if self.data[data_type] is None: print("Missing essential data: %s" % data_type) got_data = False return got_data
[docs] def save_cell_groups(self) -> str: """ Save current cell groups to session_group_ids.json in the session's data folder. Called automatically whenever groups are added/removed in the GUI. Converts internal format {cell_id: [group_ids]} to JSON format {group_name: [cell_ids]}. Returns ------- str Path to the saved file """ import json # Convert from internal format to JSON format # Internal: {cell_id: [group_id1, group_id2]} # JSON: {group_name: [cell_id1, cell_id2]} groups_json = {} for cell_id, group_ids in self.cell_ids_to_groups.items(): for group_id in group_ids: group_name = f"Group {group_id}" if group_name not in groups_json: groups_json[group_name] = [] groups_json[group_name].append(int(cell_id)) # Sort cell IDs in each group for consistency for group_name in groups_json: groups_json[group_name].sort() # Save to JSON file in the parent folder (session folder, not inside minian) parent_path = os.path.dirname(self.cnmf_path) json_path = os.path.join(parent_path, "session_group_ids.json") with open(json_path, 'w') as f: json.dump(groups_json, f, indent=2) return json_path
[docs] def load_cell_groups(self): """ Load cell groups from session_group_ids.json if it exists. Called during __init__ to auto-load existing groups when opening a session. Converts JSON format {group_name: [cell_ids]} to internal format {cell_id: [group_ids]}. Populates self.cell_ids_to_groups. """ import json parent_path = os.path.dirname(self.cnmf_path) json_path = os.path.join(parent_path, "session_group_ids.json") if not os.path.exists(json_path): return try: with open(json_path, 'r') as f: groups_json = json.load(f) # Convert from JSON format to internal format # JSON: {group_name: [cell_id1, cell_id2]} # Internal: {cell_id: [group_id1, group_id2]} self.cell_ids_to_groups = {} for group_name, cell_ids in groups_json.items(): # Extract group ID from group name (e.g., "Group 1" -> "1") if group_name.startswith("Group "): group_id = group_name[6:] # Remove "Group " prefix else: group_id = group_name for cell_id in cell_ids: if cell_id not in self.cell_ids_to_groups: self.cell_ids_to_groups[cell_id] = [group_id] else: self.cell_ids_to_groups[cell_id].append(group_id) except (json.JSONDecodeError, KeyError) as e: print(f"Error loading cell groups from {json_path}: {e}") self.cell_ids_to_groups = {}
[docs] def add_cell_id_group(self, cell_ids: List, group_id: str): """ Allocate specific cell ids to a group id This function will allocate the cell ids to a group id. This will be stored in a dictionary where the key is the cell id and the value is a set of group ids. If group_id is an empty string, we will allocate a number as the group id. Parameters ---------- cell_ids : list List of cell ids to allocate to the group group_id : str """ if group_id == "": # First find the lowest available group id current_ids = set(self.cell_ids_to_groups.values()) group_id = 1 while group_id in current_ids: group_id += 1 for id in cell_ids: if id not in self.cell_ids_to_groups: self.cell_ids_to_groups[id] = [group_id] else: self.cell_ids_to_groups[id] += [group_id] # Auto-save after adding group self.save_cell_groups()
[docs] def remove_cell_id_group(self, cell_id_group: List): """ Remove the cell ids from the group id. Parameters ---------- cell_id_group : list List of cell ids to remove from the group id """ for id in cell_id_group: if id in self.cell_ids_to_groups: del self.cell_ids_to_groups[id] # Auto-save after removing group self.save_cell_groups()
def get_group_ids(self): all_group_ids = [] for group_ids in self.cell_ids_to_groups.values(): all_group_ids += group_ids return np.unique(all_group_ids) def get_video_interval(self): timestamps = self.data['timestamp(ms)'].values # Take first 100 frames and calculate the frame rate elapsed_time = timestamps[100] - timestamps[0] frame_time = int(elapsed_time / 100) return frame_time def frame_to_time(self, frame): timestamp = self.data['timestamp(ms)'].values[frame].item() # Convert to 00:00:00.00 format seconds = timestamp / 1000 isec, fsec = divmod(round(seconds*100), 100) return "{}.{:02.0f}".format(datetime.timedelta(seconds=isec), fsec)
[docs] def get_cell_ids(self, group_id, verified=False): """ Get the cell ids for the group id. Parameters ---------- group_id : str The group id to extract the cell ids from. verified : bool If True, only extract the verified cells. """ if group_id == "All Cells": unit_ids = self.data['E'].unit_id.values elif group_id == "Verified Cells": all_unit_ids = self.data['E'].unit_id.values verified_idxs = self.data['E'].verified.values.astype(int) unit_ids = all_unit_ids[verified_idxs==1] else: if "Group" not in group_id: raise ValueError("Invalid group id") # Extract the group id from the string group_id = group_id[6:] # Find the corresponding cell ids unit_ids = [key for key, value in self.cell_ids_to_groups.items() if group_id in value] if verified: all_unit_ids = self.data['E'].unit_id.values verified_cells = self.data['E'].verified.values verified_cells = all_unit_ids[verified_cells==1] intersection = np.intersect1d(unit_ids, verified_cells) unit_ids = list(intersection) unit_ids.sort() return unit_ids
[docs] def merge_cells(self, cell_ids: List[List[int]]): """ Merge the cells in the list of cell ids. By averaging both their spatial footprints and temporal activities. The previous C, S, A, YrA, DFF and E arrays will be first backed up before the merge is performed. The E array will drop the cell ids that are not in the list of cell ids to merge and it will change the verified status to 0 for the merged cell id. Parameters ---------- cell_ids : list List of cell ids to merge. """ # For each group of cells within the list, we'll take the lowest cell id as the main cell id # and merge the rest of the cells into it. cell_mapping = {} for group in cell_ids: min_id = min(group) for id in group: if id != min_id: cell_mapping[id] = min_id unit_labels = self.data['unit_ids'] for i in range(len(unit_labels)): if unit_labels[i] in cell_mapping: unit_labels[i] = cell_mapping[unit_labels[i]] # Backup stuff here self.backup_data("A") self.backup_data("C") self.backup_data("S") self.backup_data("YrA") self.backup_data("DFF") self.backup_data("E") # Merge the spatial footprints A_merge = ( self.data["A"].assign_coords(unit_labels=("unit_id", unit_labels)) .groupby("unit_labels") .mean("unit_id") .rename(unit_labels="unit_id") ) C_merge = ( self.data["C"].assign_coords(unit_labels=("unit_id", unit_labels)) .groupby("unit_labels") .mean("unit_id") .rename(unit_labels="unit_id") ) S_merge = ( self.data["S"].assign_coords(unit_labels=("unit_id", unit_labels)) .groupby("unit_labels") .mean("unit_id") .rename(unit_labels="unit_id") ) YrA_merge = ( self.data["YrA"].assign_coords(unit_labels=("unit_id", unit_labels)) .groupby("unit_labels") .mean("unit_id") .rename(unit_labels="unit_id") ) DFF_merge = ( self.data["DFF"].assign_coords(unit_labels=("unit_id", unit_labels)) .groupby("unit_labels") .mean("unit_id") .rename(unit_labels="unit_id") ) # For E we will first go through the keys and update verified to 0 for the merged cells for key in cell_mapping.keys(): self.data["E"]["verified"].loc[dict(unit_id=key)] = 0 # Get the all values of cell_mapping into a list and drop them from E drop_keys = list(cell_mapping.keys()) self.data["E"] = self.data["E"].drop_sel(unit_id=drop_keys) # Save the new arrays self.data["A"] = overwrite_xarray(A_merge, self.cnmf_path, retrieve=True) self.data["C"] = overwrite_xarray(C_merge, self.cnmf_path, retrieve=True) self.data["S"] = overwrite_xarray(S_merge, self.cnmf_path, retrieve=True) self.data["YrA"] = overwrite_xarray(YrA_merge, self.cnmf_path, retrieve=True) self.data["E"] = overwrite_xarray(self.data["E"], self.cnmf_path, retrieve=True) self.data["DFF"] = overwrite_xarray(DFF_merge, self.cnmf_path, retrieve=True)
[docs]class CellClustering: """ Cell clustering class. This class is used to cluster cells based on their temporal activity, using FFT and agglomerative clustering. Parameters ---------- section : dict A dictionary containing the cell ids as keys and the temporal activity as values. outliers_list : list A list of cell ids that should be excluded from the clustering. A : xr.DataArray The spatial footprints of the cells. fft : bool, optional Whether to use FFT to compute the PSD. By default `True`. distance_metric : str, optional The distance metric to use for the clustering. The options are: - euclidean - cosine Attributes ---------- A : xr.DataArray The spatial footprints of the cells. psd_list_pre : dict A dictionary containing the cell ids as keys and the PSD as values. psd_list : list A list of the PSD values. outliers_list : list A list of cell ids that should be excluded from the clustering. special_unit : list A list of cell ids that have no activity. distance_metric : str The distance metric to use for the clustering. signals : dict A dictionary containing the cell ids as keys and the temporal activity as values. linkage_data : np.array The linkage data for the clustering. dendro : dict The dendrogram data. cluster_indices : np.array The cluster indices. """ def __init__( self, section: Optional[dict] = None, outliers_list: List[int] = [], A: Optional[xr.DataArray] = None, fft: bool = True, distance_metric: str = 'euclidean' ): self.A = A self.psd_list_pre = {} self.psd_list = [] self.outliers_list = outliers_list self.special_unit = [] self.distance_metric = distance_metric self.signals = {} for values in section.keys(): if values not in self.outliers_list: self.signals[values] = section[values] if fft: for unit_id in self.signals: self.compute_psd(unit_id) # compute psd for each unit else: self.psd_list = [self.signals[unit_id] for unit_id in self.signals] # Compute agglomerative clustering if self.distance_metric == 'euclidean': for unit_id in self.psd_list_pre: self.psd_list.append(self.psd_list_pre[unit_id]) self.linkage_data = linkage(self.psd_list, method='average', metric='euclidean') elif self.distance_metric == 'cosine': for (unit_id,unit_id) in zip(self.signals,self.psd_list_pre): if(all(value == 0 for value in self.psd_list_pre[unit_id]) ==True): self.special_unit.append(unit_id) else: self.psd_list.append(self.psd_list_pre[unit_id]) self.linkage_data = linkage(self.psd_list, method='average', metric='cosine')
[docs] def compute_psd(self, unit: int): """ Compute the power spectral density of the signal for a given cell. Parameters ---------- unit : int The cell id. """ val = self.signals[unit] f, psd = welch(val, fs=1./30, window='hann', nperseg=256, detrend='constant') self.psd_list_pre[unit] = psd
[docs] def visualize_dendrogram(self, color_threshold=None, ax=None): """ Apply dendrogram from scipy.cluster.hierarchy and save result to class attribute. Parameters ---------- color_threshold : float, optional The color threshold for the dendrogram. By default `None`. ax : matplotlib.axes.Axes, optional The axes to plot the dendrogram. By default `None`. Returns ------- dendro : dict The dendrogram data. """ self.dendro = dendrogram(self.linkage_data,labels=list(self.signals.keys()), color_threshold=color_threshold, ax=ax) return self.dendro
[docs] def visualize_clusters(self, t): """ Visualize the clusters by assigning a color to each cluster and looking up the corresponding footprint of each cell. Parameters ---------- t : int The number of clusters to create. Returns ------- cluster_result : dict A dictionary containing the cluster results. The keys are the cluster indices and the values are dictionaries containing the cell ids and the image of the cluster. """ self.cluster_indices = fcluster(self.linkage_data, t=t, criterion='maxclust') viridis = cm.get_cmap('jet', self.cluster_indices.max()+1) image_shape = self.A[list(self.A.keys())[0]].values.shape final_image = np.zeros((image_shape[0], image_shape[1], 3)) cluster_result = {} cluster_result["all"] = {} cluster_result["all"]["ids"] = [] cluster_result["all"]["image"] = final_image.copy() for i in range(t+1): cluster_result[i] = {} cluster_result[i]["ids"] = [] cluster_result[i]["image"] = final_image.copy() for idx, cluster in enumerate(self.cluster_indices): cluster_result[cluster]["ids"].append(list(self.signals.keys())[idx]) cluster_result["all"]["ids"].append(list(self.signals.keys())[idx]) cluster_result[cluster]["image"] += np.stack((self.A[list(self.signals.keys())[idx]].values,)*3, axis=-1) * viridis(cluster)[:3] final_image += np.stack((self.A[list(self.signals.keys())[idx]].values,)*3, axis=-1) * viridis(cluster)[:3] cluster_result["all"]["image"] = final_image return cluster_result
[docs] def visualize_clusters_color(self): """ Slightly different approach to visualize the clusters. This will color the cells based on the cluster the dendrogram results. Returns ------- matplotlib.image.AxesImage The image of the clustered cells. """ viridis = cm.get_cmap('viridis', len(np.unique(self.dendro["leaves_color_list"]))) color_mapping= {} for i, leaf in enumerate(self.dendro['leaves']): color_mapping[leaf] = int(self.dendro['leaves_color_list'][i][1]) - 1 # Convert to int image_shape = self.A[list(self.A.keys())[0]].values.shape final_image = np.zeros((image_shape[0], image_shape[1], 3)) for idx in self.dendro['leaves']: final_image += np.stack((self.A[list(self.A.keys())[idx]].values,)*3, axis=-1) * viridis(color_mapping[idx])[:3] return plt.imshow(final_image)