Skip to content

Decorators

genie_tooling.decorators

Functions

tool

tool(func: Callable) -> Callable

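Decorator that transforms a Python function into a Genie-compatible Tool. It derives JSON input and output schemas from the function's signature, type hints, and docstring, and attaches them to the returned wrapper as `_tool_metadata_` (the undecorated function is kept as `_original_function_`).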

Source code in src/genie_tooling/decorators.py
def _resolve_string_hint(type_hint: Any, globalns: Dict[str, Any]) -> Any:
    """Evaluate a type hint that is still a plain string, falling back to ``Any``.

    Non-string hints are returned unchanged.
    """
    if not isinstance(type_hint, str):
        return type_hint
    try:
        return ForwardRef(type_hint)._evaluate(  # type: ignore
            globalns, {}, recursive_guard=frozenset()
        )
    except Exception:
        # Best effort only: an unresolvable name degrades to an untyped value.
        return Any


def _unwrap_optional_hint(type_hint: Any) -> tuple:
    """Split ``None`` out of a union type hint.

    Returns:
        A ``(is_optional, hint_without_none)`` pair. ``is_optional`` is True
        only when ``type_hint`` is a union that contains ``None``; in that
        case the second item is the single remaining member, or a ``Union``
        of the remaining members when there are several. Any other hint is
        returned unchanged with ``is_optional`` False.
    """
    # Local import so this block does not rely on a module-level `types` import.
    import types

    origin = get_origin(type_hint)
    # `X | None` (PEP 604) reports `types.UnionType` as its origin on Python
    # 3.10-3.13, while `Optional[X]` / `Union[X, None]` report `typing.Union`.
    pep604_union = getattr(types, "UnionType", None)
    is_union = origin is Union or (pep604_union is not None and origin is pep604_union)
    if not is_union:
        return False, type_hint

    none_type = type(None)
    members = get_args(type_hint)
    if none_type not in members:
        return False, type_hint

    remaining = tuple(member for member in members if member is not none_type)
    if len(remaining) == 1:
        return True, remaining[0]
    return True, Union[remaining]  # Keep as Union of non-None


def tool(func: Callable) -> Callable:
    """Decorator to transform a Python function into a Genie-compatible Tool.

    The function's signature, type hints and docstring are inspected to build
    a metadata dictionary, which is attached to the returned wrapper as
    ``_tool_metadata_``. The undecorated function is attached as
    ``_original_function_``.

    Input schema rules:
        * ``self``/``cls``, ``*args`` and ``**kwargs`` parameters are skipped.
        * Each remaining parameter's type hint is converted with
          ``_map_type_to_json_schema``; an empty result defaults to
          ``{"type": "string"}``.
        * ``Optional[X]``, ``Union[X, None]`` and ``X | None`` hints add
          ``"null"`` to the schema ``type`` and make the parameter
          non-required.
        * A parameter is required when it has no default, is not optional and
          is not listed in ``FRAMEWORK_INJECTED_PARAMS``.
        * Descriptions come from ``_parse_docstring_for_params``, falling back
          to ``"Parameter '<name>'."``.

    Output schema rules:
        * The return hint (with ``None`` stripped from unions) is converted
          with ``_map_type_to_json_schema`` and exposed under ``"result"``;
          an empty result defaults to ``{"type": "object"}``.
        * ``"result"`` is required unless its type is exactly ``"null"`` or
          ``["null"]``.

    Args:
        func: The sync or async function to expose as a tool.

    Returns:
        A wrapper (async if ``func`` is a coroutine function) that forwards
        all arguments to ``func`` and carries the tool metadata.
    """
    globalns = getattr(func, "__globals__", {})
    try:
        type_hints = get_type_hints(func, globalns=globalns)
    except Exception as first_error:
        # Hint resolution is best effort: any failure (unresolved name,
        # malformed annotation, ...) must not prevent decoration.
        try:
            type_hints = get_type_hints(func)
        except Exception:
            type_hints = {}
            print(
                f"Warning: Could not fully resolve type hints for {func.__name__} due to {first_error}. Schemas might be incomplete."
            )

    sig = inspect.signature(func)
    docstring = inspect.getdoc(func) or ""
    # The first paragraph of the docstring is the tool description.
    main_description = docstring.split("\n\n")[0].strip() or f"Executes the '{func.__name__}' tool."
    param_descriptions_from_doc = _parse_docstring_for_params(docstring)
    properties: Dict[str, Any] = {}
    required_params: List[str] = []
    skipped_kinds = (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD)

    for name, param in sig.parameters.items():
        if name in ("self", "cls") or param.kind in skipped_kinds:
            continue

        param_hint = _resolve_string_hint(type_hints.get(name, Any), globalns)
        is_optional, schema_type = _unwrap_optional_hint(param_hint)

        mapped_schema = _map_type_to_json_schema(schema_type)
        # Copy before annotating so that a dict the mapper might share or
        # cache is never mutated; an empty schema (e.g. from Any) becomes
        # 'string'.
        param_schema_def: Dict[str, Any] = dict(mapped_schema) if mapped_schema else {"type": "string"}
        param_schema_def["description"] = param_descriptions_from_doc.get(name, f"Parameter '{name}'.")

        # Handle optionality by adding "null" to the type list.
        if is_optional:
            declared_type = param_schema_def.get("type")
            if isinstance(declared_type, str):
                param_schema_def["type"] = [declared_type, "null"]
            elif isinstance(declared_type, list) and "null" not in declared_type:
                param_schema_def["type"] = [*declared_type, "null"]

        if param.default is inspect.Parameter.empty:
            if not is_optional and name not in FRAMEWORK_INJECTED_PARAMS:
                required_params.append(name)
        else:
            param_schema_def["default"] = param.default

        properties[name] = param_schema_def

    input_schema: Dict[str, Any] = {"type": "object", "properties": properties}
    if required_params:
        input_schema["required"] = required_params

    return_hint = _resolve_string_hint(type_hints.get("return", Any), globalns)
    # Same None-stripping as for parameters, so a multi-member union keeps
    # every non-None member instead of only the first one.
    _, return_schema_type = _unwrap_optional_hint(return_hint)
    output_schema_prop_def = _map_type_to_json_schema(return_schema_type) or {"type": "object"}

    output_schema: Dict[str, Any] = {"type": "object", "properties": {"result": output_schema_prop_def}}
    # A result that can only ever be null is not marked as required.
    if output_schema_prop_def.get("type") not in ("null", ["null"]):
        output_schema["required"] = ["result"]

    tool_metadata = {
        "identifier": func.__name__,
        "name": func.__name__.replace("_", " ").title(),
        "description_human": main_description,
        "description_llm": main_description,
        "input_schema": input_schema,
        "output_schema": output_schema,
        "key_requirements": [],
        "tags": ["decorated_tool"],
        "version": "1.0.0",
        "cacheable": False,
    }

    # Preserve the sync/async nature of the decorated function.
    if inspect.iscoroutinefunction(func):

        @functools.wraps(func)
        async def wrapper(*args, **kwargs):
            return await func(*args, **kwargs)

    else:

        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            return func(*args, **kwargs)

    wrapper._tool_metadata_ = tool_metadata  # type: ignore
    wrapper._original_function_ = func  # type: ignore
    return wrapper