def _resolve_tool_hint(hint: Any, globalns: Dict[str, Any]) -> Any:
    """Evaluate a string (forward-reference) annotation against ``globalns``.

    Non-string hints are returned unchanged. A string that cannot be evaluated
    degrades to ``Any`` so that a single bad annotation never breaks decoration.
    """
    if not isinstance(hint, str):
        return hint
    try:
        # NOTE(review): ForwardRef._evaluate is private typing API and its
        # signature differs between Python versions; any failure lands in the
        # except branch below and yields Any.
        return ForwardRef(hint)._evaluate(  # type: ignore
            globalns, {}, recursive_guard=frozenset()
        )
    except Exception:
        return Any


def _split_optional_hint(hint: Any) -> tuple:
    """Split a type hint into ``(is_optional, hint_without_none)``.

    Recognises ``typing.Optional[X]`` / ``typing.Union[..., None]`` as well as
    the PEP 604 ``X | None`` form. When several non-None members remain they
    are kept together as a ``typing.Union``. Hints that do not include ``None``
    are returned unchanged with ``is_optional`` set to False.
    """
    import types  # local import: types.UnionType only exists on Python 3.10+

    pep604_union = getattr(types, "UnionType", None)
    origin = get_origin(hint)
    is_union = origin is Union or (pep604_union is not None and origin is pep604_union)
    args = get_args(hint)
    if not is_union or type(None) not in args:
        return False, hint
    non_none_args = [t for t in args if t is not type(None)]
    if len(non_none_args) == 1:
        return True, non_none_args[0]
    # Keep as a Union of the non-None members (normalises `A | B` to typing.Union).
    return True, Union[tuple(non_none_args)]


def _make_schema_nullable(schema: Dict[str, Any]) -> None:
    """Add ``"null"`` to ``schema["type"]`` in place when a type is present.

    The type list is rebuilt rather than appended to, so any list object that
    is shared with another schema is left untouched.
    """
    current = schema.get("type")
    if isinstance(current, str):
        if current != "null":
            schema["type"] = [current, "null"]
    elif isinstance(current, list) and "null" not in current:
        schema["type"] = [*current, "null"]


def tool(func: Callable) -> Callable:
    """Decorator to transform a Python function into a Genie-compatible Tool.

    Builds JSON-schema style input/output descriptions from ``func``'s
    signature, type hints and docstring. It then returns a pass-through
    wrapper (async when ``func`` is a coroutine function) that carries:

    * ``_tool_metadata_``: dict with ``identifier``, ``name``,
      ``description_human``, ``description_llm``, ``input_schema``,
      ``output_schema``, ``key_requirements``, ``tags``, ``version`` and
      ``cacheable``.
    * ``_original_function_``: the undecorated ``func``.

    Args:
        func: The function to expose as a tool.

    Returns:
        The wrapper described above. Calling it simply forwards to ``func``.

    Input schema rules:
        ``self``/``cls`` and ``*args``/``**kwargs`` are skipped. A parameter
        is listed in ``required`` only when it has no default, is not
        ``Optional``/``X | None``, and is not in ``FRAMEWORK_INJECTED_PARAMS``.
        Optional types gain ``"null"`` in their ``type``, and an empty mapped
        schema (e.g. from ``Any``) defaults to ``{"type": "string"}``.
    """
    globalns = getattr(func, "__globals__", {})
    try:
        type_hints: Dict[str, Any] = get_type_hints(func, globalns=globalns)
    except (NameError, TypeError, AttributeError) as e:
        # One unresolvable annotation makes get_type_hints fail for the whole
        # function, and retrying with default namespaces fails identically.
        # Fall back to the raw annotations instead: string entries are
        # resolved one by one below, and only the broken ones degrade to Any.
        type_hints = dict(getattr(func, "__annotations__", None) or {})
        print(
            f"Warning: Could not fully resolve type hints for {func.__name__} due to {e}. Schemas might be incomplete."
        )

    sig = inspect.signature(func)
    docstring = inspect.getdoc(func) or ""
    # The first docstring paragraph serves as the tool description.
    main_description = docstring.split("\n\n")[0].strip() or f"Executes the '{func.__name__}' tool."
    param_descriptions_from_doc = _parse_docstring_for_params(docstring)

    # --- Input schema -------------------------------------------------------
    properties: Dict[str, Any] = {}
    required_params: List[str] = []
    for name, param in sig.parameters.items():
        if name in ("self", "cls") or param.kind in (
            inspect.Parameter.VAR_POSITIONAL,
            inspect.Parameter.VAR_KEYWORD,
        ):
            continue
        param_hint = _resolve_tool_hint(type_hints.get(name, Any), globalns)
        is_optional, schema_hint = _split_optional_hint(param_hint)
        # Work on a copy so the edits below never touch a dict the mapper
        # might share between calls.
        param_schema_def = dict(_map_type_to_json_schema(schema_hint) or {})
        if not param_schema_def:
            # An empty schema (e.g. from Any) gives no type; default to string.
            param_schema_def["type"] = "string"
        param_schema_def["description"] = param_descriptions_from_doc.get(name, f"Parameter '{name}'.")
        if is_optional:
            _make_schema_nullable(param_schema_def)
        if param.default is inspect.Parameter.empty:
            if not is_optional and name not in FRAMEWORK_INJECTED_PARAMS:
                required_params.append(name)
        else:
            param_schema_def["default"] = param.default
        properties[name] = param_schema_def

    input_schema: Dict[str, Any] = {"type": "object", "properties": properties}
    if required_params:
        input_schema["required"] = required_params

    # --- Output schema ------------------------------------------------------
    return_hint = _resolve_tool_hint(type_hints.get("return", Any), globalns)
    return_is_optional, return_schema_hint = _split_optional_hint(return_hint)
    output_schema_prop_def = dict(_map_type_to_json_schema(return_schema_hint) or {})
    if not output_schema_prop_def:
        output_schema_prop_def = {"type": "object"}
    if return_is_optional:
        # Same treatment as optional parameters: the result may be null.
        _make_schema_nullable(output_schema_prop_def)
    output_schema: Dict[str, Any] = {"type": "object", "properties": {"result": output_schema_prop_def}}
    # A result whose only allowed type is null carries nothing to require.
    if output_schema_prop_def.get("type") not in ("null", ["null"]):
        output_schema["required"] = ["result"]

    tool_metadata = {
        "identifier": func.__name__,
        "name": func.__name__.replace("_", " ").title(),
        "description_human": main_description,
        "description_llm": main_description,
        "input_schema": input_schema,
        "output_schema": output_schema,
        "key_requirements": [],
        "tags": ["decorated_tool"],
        "version": "1.0.0",
        "cacheable": False,
    }

    # Preserve sync/async nature so callers can still await coroutine tools.
    if inspect.iscoroutinefunction(func):
        @functools.wraps(func)
        async def wrapper(*args, **kwargs):
            return await func(*args, **kwargs)
    else:
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            return func(*args, **kwargs)

    wrapper._tool_metadata_ = tool_metadata  # type: ignore
    wrapper._original_function_ = func  # type: ignore
    return wrapper