nnsight.intervention

This module contains logic to interleave a computation graph (an intervention graph) with the computation graph of a model.

The InterventionProxy class extends the base nnsight.tracing.Proxy.Proxy class with functionality that makes proxies easier for users to work with.

intervene() is the entry hook into the model's computation graph through which an intervention graph is interleaved.

The HookHandler class provides a context manager that adds input and output hooks to modules and removes them when the context exits.

class nnsight.intervention.HookHandler(model: Module, module_keys: List[str], input_hook: Callable | None = None, output_hook: Callable | None = None)

Context manager that applies input and/or output hooks to modules in a model.

Registers provided hooks on __enter__ and removes them on __exit__.

model

Root model to access modules and apply hooks to.

Type:

torch.nn.Module

modules

List of modules to apply hooks to, each paired with its module_path.

Type:

List[Tuple[torch.nn.Module, str]]

input_hook

Function to apply to the inputs of the designated modules. It should have the signature [inputs (Any), module_path (str)] -> inputs (Any).

Type:

Callable

output_hook

Function to apply to the outputs of the designated modules. It should have the signature [outputs (Any), module_path (str)] -> outputs (Any).

Type:

Callable

handles

Handles returned when the hooks are registered, used to remove the hooks on __exit__.

Type:

List[RemovableHandle]
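A minimal sketch of the register-on-enter, remove-on-exit pattern described above, assuming the hooks are implemented with PyTorch's register_forward_pre_hook and register_forward_hook. HookHandlerSketch is a hypothetical name, not nnsight's code, and module paths are resolved here with torch.nn.Module.get_submodule as an assumption.

```python
from typing import Any, Callable, List, Optional, Tuple

import torch
from torch.utils.hooks import RemovableHandle


class HookHandlerSketch:
    """Hypothetical re-creation of the documented HookHandler behaviour."""

    def __init__(
        self,
        model: torch.nn.Module,
        module_keys: List[str],
        input_hook: Optional[Callable[[Any, str], Any]] = None,
        output_hook: Optional[Callable[[Any, str], Any]] = None,
    ) -> None:
        self.model = model
        # Resolve each dotted module path against the root model (assumed lookup method).
        self.modules: List[Tuple[torch.nn.Module, str]] = [
            (model.get_submodule(key), key) for key in module_keys
        ]
        self.input_hook = input_hook
        self.output_hook = output_hook
        self.handles: List[RemovableHandle] = []

    def __enter__(self) -> "HookHandlerSketch":
        for module, module_path in self.modules:
            if self.input_hook is not None:
                # Pre-hooks receive (module, args); a non-None return replaces args.
                def pre_hook(mod, args, path=module_path):
                    return self.input_hook(args, path)

                self.handles.append(module.register_forward_pre_hook(pre_hook))
            if self.output_hook is not None:
                # Forward hooks receive (module, args, output); a non-None return replaces output.
                def post_hook(mod, args, output, path=module_path):
                    return self.output_hook(output, path)

                self.handles.append(module.register_forward_hook(post_hook))
        return self

    def __exit__(self, exc_type, exc, tb) -> None:
        # Remove every hook so the model behaves normally after the block.
        for handle in self.handles:
            handle.remove()
        self.handles.clear()


model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.ReLU())
x = torch.ones(1, 4)


def double_output(output: Any, module_path: str) -> Any:
    # Output hook following the [outputs, module_path] -> outputs convention.
    return output * 2


baseline = model(x)
with HookHandlerSketch(model, ["0"], output_hook=double_output) as handler:
    hooked = model(x)  # the Linear output ("0") is doubled before the ReLU
restored = model(x)  # hooks are gone, so this matches the baseline
```

Because ReLU(2z) equals 2·ReLU(z), the hooked output is exactly twice the baseline, and once the context exits the model is back to its original behaviour.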

class nnsight.intervention.InterventionProxy(node: Node)

Subclass of Proxy that adds user-facing functionality to proxies.

Examples

Calling .save() on a proxy keeps its value from being deleted once all of its listeners have completed:

with runner.invoke('The Eiffel Tower is in the city of') as invoker:
    hidden_states = model.lm_head.input.save()
    logits = model.lm_head.output.save()

print(hidden_states.value)
print(logits.value)

This prints the input and output of the model.lm_head module. Without the calls to .save(), .value would have been None.

Accessing .shape on an InterventionProxy returns the shape, or collection of shapes, of the tensors traced through this module.

Accessing .value on an InterventionProxy returns the actual values, populated during execution of the model.

property device: Collection[device]

Property to retrieve the device of the traced proxy value or real value.

Returns:

Proxy value device or collection of devices.

Return type:

Union[torch.device,Collection[torch.device]]

property dtype: Collection[dtype]

Property to retrieve the dtype of the traced proxy value or real value.

Returns:

Proxy value dtype or collection of dtypes.

Return type:

Union[torch.dtype,Collection[torch.dtype]]

property grad: InterventionProxy

Accessing this property indicates that the user wants the gradient of the proxied tensor, so a Proxy representing that request is created. The proxy is generated only the first time the property is referenced; later accesses return the existing one.

Returns:

Grad proxy.

Return type:

Proxy
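The create-once-then-reuse behaviour of grad is a lazily cached property. A minimal sketch, using a hypothetical ProxySketch class rather than nnsight's Proxy, which models only the caching and not the gradient request itself:

```python
from typing import Optional


class ProxySketch:
    """Hypothetical stand-in for a proxy; only the grad-caching pattern is modelled."""

    def __init__(self, name: str) -> None:
        self.name = name
        self._grad: Optional["ProxySketch"] = None

    @property
    def grad(self) -> "ProxySketch":
        # Create the grad proxy on first access only; return the cached one afterwards.
        if self._grad is None:
            self._grad = ProxySketch(f"{self.name}.grad")
        return self._grad


hidden = ProxySketch("hidden")
first = hidden.grad
second = hidden.grad  # same object as `first`, no new proxy is created
```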

save() -> InterventionProxy

When called, indicates to the intervention graph that the tensor values of the result should not be deleted.

Returns:

Save proxy.

Return type:

InterventionProxy
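The effect of save() can be modelled with a toy node that frees its value once every listener has consumed it, unless the node was saved. NodeSketch and its methods are hypothetical; the sketch illustrates the described behaviour only and is not nnsight's graph implementation.

```python
from typing import Any, Optional


class NodeSketch:
    """Hypothetical graph node whose value is freed after its last listener runs."""

    def __init__(self, name: str, listeners: int) -> None:
        self.name = name
        self.remaining_listeners = listeners
        self.saved = False
        self.value: Optional[Any] = None

    def save(self) -> "NodeSketch":
        # Mark the node so its value survives its listeners; return self for chaining.
        self.saved = True
        return self

    def set_value(self, value: Any) -> None:
        self.value = value
        self._maybe_free()

    def consume(self) -> Any:
        # Called by a downstream listener; the value is freed after the last one.
        value = self.value
        self.remaining_listeners -= 1
        self._maybe_free()
        return value

    def _maybe_free(self) -> None:
        if self.remaining_listeners <= 0 and not self.saved:
            self.value = None


kept = NodeSketch("logits", listeners=1).save()
dropped = NodeSketch("hidden_states", listeners=1)
for node in (kept, dropped):
    node.set_value([1.0, 2.0])
    node.consume()  # the single listener finishes
```

After the loop, kept.value still holds the data while dropped.value is None, mirroring the .value behaviour described for InterventionProxy.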

property shape: Collection[Size]

Property to retrieve the shape of the traced proxy value or real value.

Returns:

Proxy value shape or collection of shapes.

Return type:

Union[torch.Size,Collection[torch.Size]]
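The single-value-or-collection return of shape, device and dtype can be pictured with a small recursive helper that applies one attribute lookup to every tensor in a module's (possibly nested) inputs or outputs. This is an assumption about the general idea, not nnsight's implementation; map_tensors is a hypothetical name, and only tuples, lists and dicts are traversed.

```python
from typing import Any, Callable

import torch


def map_tensors(value: Any, fn: Callable[[torch.Tensor], Any]) -> Any:
    """Hypothetical helper: apply fn to each tensor in a nested structure."""
    if isinstance(value, torch.Tensor):
        # A single tensor yields a single result, e.g. one torch.Size.
        return fn(value)
    if isinstance(value, tuple):
        return tuple(map_tensors(v, fn) for v in value)
    if isinstance(value, list):
        return [map_tensors(v, fn) for v in value]
    if isinstance(value, dict):
        return {k: map_tensors(v, fn) for k, v in value.items()}
    # Non-tensor leaves (ints, None, ...) pass through unchanged.
    return value


# e.g. a module whose output is a tuple of two tensors
output = (torch.zeros(2, 3), torch.ones(2, 5, dtype=torch.float16))
shapes = map_tensors(output, lambda t: t.shape)
dtypes = map_tensors(output, lambda t: t.dtype)
devices = map_tensors(output, lambda t: t.device)
```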

nnsight.intervention.intervene(activations: Any, module_path: str, key: str, intervention_handler: InterventionHandler)

Entry point into the intervention graph. It should be hooked to every module involved in the intervention graph.

Forms the key for the current module in the form <module path>.<output/input> and checks the graph's argument_node_names attribute for this key. If the key exists, its value is a list of node names to iterate through. The args of argument-type nodes should be [module_path, batch_size, batch_start, call_iter]. For each node:

  • Checks and updates the counter for the argument node. If the counter is not ready yet, the node is skipped.

  • Using batch_size and batch_start, applies torch.narrow to the tensors in activations to select only the batch-indexed slices relevant to this intervention node.

  • Sets the value of the node to these indexed values.

torch.narrow returns a view of a tensor rather than a copy, so downstream nodes can edit values in only the relevant slices and have those edits update the original tensors. This both prevents interventions from affecting batches outside their purview and lets downstream intervention nodes in the graph edit the output.

Parameters:
  • activations (Any) – Either the inputs or outputs of a torch module.

  • module_path (str) – Module path of the current relevant module relative to the root model.

  • key (str) – Key denoting either “input” or “output” of module.

  • intervention_handler (InterventionHandler) – Handler object that stores the intervention graph and keeps track of module call count.

Returns:

The activations, potentially modified by the intervention graph.

Return type:

Any
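The lookup, counter and narrowing steps can be sketched as follows. This is a simplified, hypothetical model: HandlerSketch and intervene_sketch are illustrative names, activations are assumed to be a single tensor with the batch on dimension 0, and the counter rule (fire only when the call count equals call_iter) is one plausible reading of "not ready yet", not nnsight's exact logic.

```python
from typing import Dict, List

import torch


class HandlerSketch:
    """Hypothetical stand-in for InterventionHandler."""

    def __init__(self, argument_node_names: Dict[str, List[str]],
                 node_args: Dict[str, list]) -> None:
        # "<module path>.<input/output>" -> names of argument nodes for that key
        self.argument_node_names = argument_node_names
        # node name -> [module_path, batch_size, batch_start, call_iter]
        self.node_args = node_args
        # how many times each argument node has been reached so far
        self.call_counter: Dict[str, int] = {name: 0 for name in node_args}
        # node name -> value assigned to that node
        self.values: Dict[str, torch.Tensor] = {}


def intervene_sketch(activations: torch.Tensor, module_path: str, key: str,
                     handler: HandlerSketch) -> torch.Tensor:
    module_key = f"{module_path}.{key}"
    for node_name in handler.argument_node_names.get(module_key, []):
        _, batch_size, batch_start, call_iter = handler.node_args[node_name]
        # Check, then update, the counter; skip the node until its call comes up.
        ready = handler.call_counter[node_name] == call_iter
        handler.call_counter[node_name] += 1
        if not ready:
            continue
        # narrow() returns a view, so edits to the node value write back into activations.
        handler.values[node_name] = activations.narrow(0, batch_start, batch_size)
    return activations


# One argument node covering batch row 1 of "layer.output" on the first call.
handler = HandlerSketch(
    argument_node_names={"layer.output": ["out_row1"]},
    node_args={"out_row1": ["layer", 1, 1, 0]},
)
acts = torch.zeros(3, 2)
result = intervene_sketch(acts, "layer", "output", handler)
handler.values["out_row1"][:] = 7.0  # a downstream edit through the view
```

After the edit, only row 1 of acts holds 7.0 while rows 0 and 2 stay at zero: the write reaches the original tensor, but only inside the node's own batch slice.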