Skip to content

signature_checker

check_compatible(function, func_name, expected_params)

Check if a function is compatible with the expected signature.

Args:
- function (callable): The function to check.
- func_name (str): The name of the function.
- expected_params (list): A list of dictionaries containing the expected parameter name, type, and position.

Returns:
- function (callable): The original function if it's compatible.

Raises:
- AttributeError: If the function is not compatible with the expected signature.

Source code in src/lmflux/utils/signature_checker.py
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
def check_compatible(function: callable, func_name: str, expected_params: list):
    """
    Check if a function is compatible with the expected signature.

    The function is compatible when it declares exactly ``len(expected_params)``
    parameters and every expected entry is satisfied: the parameter at index
    ``position`` is named ``name`` and, unless ``type`` is ``None``, its
    annotation compares equal to ``type``.

    Args:
    - function (callable): The function to check.
    - func_name (str): The name of the function, used in the error message.
    - expected_params (list): A list of dictionaries containing the expected
      parameter ``'name'``, ``'type'`` (``None`` accepts any annotation) and
      zero-based ``'position'``.

    Returns:
    - function (callable): The original function if it's compatible.

    Raises:
    - AttributeError: If the function is not compatible with the expected signature.
    """
    parameters = inspect.signature(function).parameters.values()
    params_match = [False] * len(expected_params)

    for position, param in enumerate(parameters):
        for i, expected_param in enumerate(expected_params):
            if (param.name == expected_param['name'] and
                    (expected_param['type'] is None or param.annotation == expected_param['type']) and
                    position == expected_param['position']):
                params_match[i] = True

    # Every expectation must be met AND the function must not declare extra parameters.
    if all(params_match) and len(parameters) == len(expected_params):
        return function

    def _type_label(expected_type) -> str:
        # Plain classes expose ``__name__``; constructs such as ``int | None``
        # may not, so fall back to repr() rather than letting the message
        # builder raise an unrelated AttributeError that hides the real problem.
        if expected_type is None:
            return 'any'
        return getattr(expected_type, '__name__', None) or repr(expected_type)

    # functools.partial objects and callable instances have no ``__name__``.
    shown_name = getattr(function, '__name__', func_name)
    correct_sig = (
        f"def {shown_name}("
        + ", ".join(f"{param['name']}: {_type_label(param['type'])}" for param in expected_params)
        + "):..."
    )
    raise AttributeError(f"{func_name} must be defined as {correct_sig}")