trial_matching_algorithm
Algorithm for matching patients to clinical trials.
This algorithm takes a per-patient DataFrame (with ICD-10 condition codes and demographics) and a trials DataFrame (from SqlQuery) and computes a match score for each patient-trial pair. The score is normalized to [0, 1]: 1 = all inclusion codes matched and no exclusion codes matched; 0 = no inclusion codes matched and all exclusion codes matched. Returns the top N trials per patient.
Classes
TrialMatchingAlgorithm
class TrialMatchingAlgorithm(*, datastructure: Optional[DataStructure] = None, top_n: int = 5, max_distance_km: Optional[float] = None, default_country: str = 'us'):
Algorithm for matching patients to clinical trials via ICD-10 codes.
Uses patient condition codes (ICD-10 only) and trial inclusion/exclusion codes to compute match scores. Score is normalized to [0, 1]: 1 = perfect match (all inclusion matched, no exclusion matched); 0 = worst.
Patient postal code and country enable distance filtering via pgeocode.
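The distance filter can be pictured with a small sketch. The library resolves postal codes through pgeocode; the code below skips that lookup and takes (latitude, longitude) pairs directly, and it assumes a great-circle (haversine) distance. The names haversine_km, nearest_site_km, and within_range are hypothetical and exist only for this illustration.

```python
import math
from typing import Optional, Sequence, Tuple

Coord = Tuple[float, float]  # (latitude, longitude) in degrees
EARTH_RADIUS_KM = 6371.0  # mean Earth radius


def haversine_km(lat1: float, lon1: float, lat2: float, lon2: float) -> float:
    """Great-circle distance in km between two points given in degrees."""
    phi1, phi2 = math.radians(lat1), math.radians(lat2)
    dphi = phi2 - phi1
    dlambda = math.radians(lon2 - lon1)
    a = (
        math.sin(dphi / 2) ** 2
        + math.cos(phi1) * math.cos(phi2) * math.sin(dlambda / 2) ** 2
    )
    return 2 * EARTH_RADIUS_KM * math.asin(math.sqrt(a))


def nearest_site_km(patient: Coord, sites: Sequence[Coord]) -> Optional[float]:
    """Distance to the closest site, or None when there are no sites."""
    if not sites:
        return None
    return min(haversine_km(*patient, *site) for site in sites)


def within_range(
    patient: Coord, sites: Sequence[Coord], max_distance_km: Optional[float]
) -> bool:
    """Apply the distance filter; a max_distance_km of None disables it."""
    if max_distance_km is None:
        return True
    nearest = nearest_site_km(patient, sites)
    return nearest is not None and nearest <= max_distance_km


# Two sites on the equator, one and five degrees of longitude away.
sites = [(0.0, 1.0), (0.0, 5.0)]
keep_150 = within_range((0.0, 0.0), sites, 150.0)
keep_100 = within_range((0.0, 0.0), sites, 100.0)
```

One degree of longitude at the equator is roughly 111 km, so the nearest site passes a 150 km limit but not a 100 km limit.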
Trial matching logic (exact sequence):
- Group trials by nct_id: aggregate inclusion/exclusion ICD-10 codes, demographics (age range, sex_eligible), and site locations from trials_df.
- For each patient row:
  a. Build patient profile: extract ICD-10 codes from Conditions, compute age from date_of_birth, normalize gender and country.
- For each patient, match against all trials. For each trial, in order:
  a. Demographic check: skip if patient age is outside the trial min/max, or sex does not match sex_eligible (when restricted).
  b. ICD-10 code matching (done before distance for performance): skip trials with no inclusion ICD-10 codes; compute matched_inclusion = patient_icd10 ∩ trial inclusion codes; skip if matched_inclusion is empty; compute matched_exclusion.
  c. Distance check (if max_distance_km is set): skip if the nearest trial site is farther than max_distance_km; skip trials with no sites in the same country as the patient.
  d. Score = (len(matched_inclusion) - len(matched_exclusion) + n_excl) / (n_incl + n_excl), range [0, 1], where n_incl/n_excl are the trial's inclusion/exclusion code counts.
  e. Append result with nct_id, title, score, nearest_site_km, explanation.
- Sort matched trials by score descending; return top N per patient.
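The sequence above can be sketched in plain Python. This is a minimal illustration under stated assumptions, not the library's implementation: the Trial dataclass, match_score, and top_trials are hypothetical names, the trial IDs are placeholders, and the distance check (step c) is left out. The score assumes the normalization (matched inclusion - matched exclusion + n_excl) / (n_incl + n_excl), which yields 1 when every inclusion code and no exclusion code matches.

```python
from dataclasses import dataclass, field
from typing import List, Optional, Set, Tuple


@dataclass
class Trial:
    """Hypothetical container for one trial after grouping by nct_id."""

    nct_id: str
    inclusion: Set[str]
    exclusion: Set[str] = field(default_factory=set)
    min_age: Optional[int] = None
    max_age: Optional[int] = None
    sex_eligible: Optional[str] = None  # None means unrestricted


def match_score(patient_codes: Set[str], trial: Trial) -> Optional[float]:
    """Return the normalized score, or None when the trial is skipped."""
    if not trial.inclusion:
        return None  # trials without inclusion codes are skipped
    matched_inclusion = patient_codes & trial.inclusion
    if not matched_inclusion:
        return None  # no inclusion code matched
    matched_exclusion = patient_codes & trial.exclusion
    n_incl, n_excl = len(trial.inclusion), len(trial.exclusion)
    return (len(matched_inclusion) - len(matched_exclusion) + n_excl) / (
        n_incl + n_excl
    )


def top_trials(
    patient_codes: Set[str],
    age: int,
    sex: str,
    trials: List[Trial],
    top_n: int = 5,
) -> List[Tuple[str, float]]:
    """Demographic check, code matching, scoring, then top-N selection."""
    results = []
    for trial in trials:
        if trial.min_age is not None and age < trial.min_age:
            continue
        if trial.max_age is not None and age > trial.max_age:
            continue
        if trial.sex_eligible is not None and sex != trial.sex_eligible:
            continue
        score = match_score(patient_codes, trial)
        if score is None:
            continue
        results.append((trial.nct_id, score))
    results.sort(key=lambda item: item[1], reverse=True)
    return results[:top_n]


trials = [
    Trial("NCT-A", {"E11.9", "I10"}, {"N18.6"}),  # both inclusion codes match
    Trial("NCT-B", {"E11.9"}, {"I10"}),           # one exclusion code matches
    Trial("NCT-C", {"J45.909"}),                  # no inclusion match: skipped
    Trial("NCT-D", {"E11.9"}, min_age=70),        # patient too young: skipped
]
ranked = top_trials({"E11.9", "I10"}, age=54, sex="female", trials=trials)
print(ranked)  # [('NCT-A', 1.0), ('NCT-B', 0.5)]
```

NCT-A scores (2 - 0 + 1) / 3 = 1.0 and NCT-B scores (1 - 1 + 1) / 2 = 0.5, which shows how a matched exclusion code lowers the score without removing the trial from the results.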
Arguments
- **kwargs: Additional keyword arguments.
- datastructure: The data structure to use for the algorithm.
- default_country: ISO alpha-2 code used when patient or site country is missing. Defaults to "us". Injected into country normalization.
- max_distance_km: Optional maximum distance in km to the nearest trial site. Ignored when None.
- top_n: Number of top trials to return per patient. Defaults to 5.
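As a rough illustration of how default_country might feed country normalization, the sketch below falls back to the default when a value is missing or blank. The helper normalize_country is hypothetical, and the lower-casing is an assumption based on the lower-case default of "us"; the library's actual normalization may differ.

```python
from typing import Optional


def normalize_country(value: Optional[str], default_country: str = "us") -> str:
    """Hypothetical helper: return a lower-case alpha-2 code.

    Missing or blank values fall back to default_country.
    """
    if value is None or not value.strip():
        return default_country.lower()
    return value.strip().lower()


patient_country = normalize_country(None)   # falls back to the default
site_country = normalize_country(" DE ")    # trimmed and lower-cased
```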
Attributes
- class_name: The name of the algorithm class.
- default_country: Default when country is missing (from task/algorithm args).
- fields_dict: A dictionary mapping all attributes that will be serialized in the class to their marshmallow field type (e.g. fields_dict = {"class_name": fields.Str()}).
- max_distance_km: Maximum allowed distance to trial site in km, or None.
- nested_fields: A dictionary mapping all nested attributes to a registry that contains class names mapped to the respective classes (e.g. nested_fields = {"datastructure": datastructure.registry}).
- top_n: Number of top trials to return per patient.
Variables
- static fields_dict : ClassVar[dict[str, marshmallow.fields.Field]]
Methods
create
def create(self, role: Union[str, Role], **kwargs: Any) -> Any:
Create an instance representing the role specified.
modeller
def modeller(self, *, context: ProtocolContext, **kwargs: Any) -> NoResultsModellerAlgorithm:
Returns the modeller side of the TrialMatchingAlgorithm.
worker
def worker(self, *, context: ProtocolContext, **kwargs: Any) -> bitfount.federated.algorithms.ehr.trial_matching_algorithm._WorkerSide:
Returns the worker side of the TrialMatchingAlgorithm.