nnsight.intervention#

This module contains logic to interleave a computation graph (an intervention graph) with the computation graph of a model.

The InterventionProxy class extends the functionality of a base nnsight.tracing.Proxy.Proxy object, making it easier for users to interact with.

intervene() is the entry hook into the model's computation graph, used to interleave an intervention graph.

The HookHandler class provides a context manager for adding input and output hooks to modules and removing them upon context exit.

class nnsight.intervention.HookHandler(model: Module, module_keys: List[str], input_hook: Callable = None, output_hook: Callable = None)[source]#

Context manager that applies input and/or output hooks to modules in a model.

Registers provided hooks on __enter__ and removes them on __exit__.

model#

Root model to access modules and apply hooks to.

Type:

torch.nn.Module

modules#

List of modules to apply hooks to along with their module_path.

Type:

List[Tuple[torch.nn.Module, str]]

input_hook#

Function to apply to inputs of designated modules. Should have the signature (inputs: Any, module_path: str) -> inputs: Any

Type:

Callable

output_hook#

Function to apply to outputs of designated modules. Should have the signature (outputs: Any, module_path: str) -> outputs: Any

Type:

Callable

handles#

Handles returned from registering hooks, used to remove the hooks on __exit__.

Type:

List[RemovableHandle]
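The register-on-enter / remove-on-exit pattern that HookHandler implements can be sketched in plain Python. The MiniModule, RemovableHandle, and MiniHookHandler classes below are hypothetical stand-ins for illustration, not the nnsight or torch API:

```python
class RemovableHandle:
    """Minimal stand-in for torch.utils.hooks.RemovableHandle."""

    def __init__(self, hooks, hook):
        self.hooks, self.hook = hooks, hook

    def remove(self):
        self.hooks.remove(self.hook)


class MiniModule:
    """Hypothetical module that runs output hooks after its forward pass."""

    def __init__(self, path):
        self.path = path
        self.output_hooks = []

    def register_output_hook(self, hook):
        self.output_hooks.append(hook)
        return RemovableHandle(self.output_hooks, hook)

    def __call__(self, x):
        out = x * 2
        for hook in self.output_hooks:
            out = hook(out, self.path)
        return out


class MiniHookHandler:
    """Sketch of HookHandler: register hooks on __enter__, remove on __exit__."""

    def __init__(self, modules, output_hook):
        self.modules = modules
        self.output_hook = output_hook
        self.handles = []

    def __enter__(self):
        for module in self.modules:
            self.handles.append(module.register_output_hook(self.output_hook))
        return self

    def __exit__(self, *exc):
        # Remove every registered hook so the modules are left untouched.
        for handle in self.handles:
            handle.remove()


module = MiniModule("model.layer1")
seen = []

with MiniHookHandler([module], lambda out, path: seen.append((path, out)) or out):
    module(3)  # hook fires inside the context

module(4)  # hook has been removed; nothing recorded
```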

class nnsight.intervention.InterventionHandler(graph: Graph, batch_groups: List[Tuple[int, int]], batch_size: int)[source]#

Object passed to InterventionProtocol.intervene to store information about the current interleaving execution run.

This includes the intervention Graph, the total batch size being executed, and a counter tracking how many times each intervention node has been attempted.

count(name: str) int[source]#

Increments the count of times a given intervention node has attempted to execute and returns the new count.

Parameters:

name (str) – Name of intervention node to return count for.

Returns:

Count.

Return type:

int
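The per-node execution counter that count implements can be sketched as follows (MiniInterventionHandler is a hypothetical stand-in, not the nnsight implementation):

```python
from collections import defaultdict


class MiniInterventionHandler:
    """Sketch of the per-node call counter kept by InterventionHandler."""

    def __init__(self):
        self.call_counter = defaultdict(int)

    def count(self, name: str) -> int:
        # Increment the number of times this intervention node has
        # attempted to execute, and return the running count.
        self.call_counter[name] += 1
        return self.call_counter[name]


handler = MiniInterventionHandler()
handler.count("model.layer1.output.0")  # first attempt
handler.count("model.layer1.output.0")  # second attempt
```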

class nnsight.intervention.InterventionProtocol[source]#

Primary Protocol that handles tracking and injecting inputs and outputs from a torch model into the overall intervention Graph. Uses an attachment on the Graph to store the names of nodes that need to be injected with data from inputs or outputs of modules.

classmethod add(graph: Graph, proxy_value: Any, args: List[Any] = None, kwargs: Dict[str, Any] = None) Proxy[source]#

Adds an InterventionProtocol Node to a Graph.

Parameters:
  • graph (Graph) – Graph to add to.

  • module_path (str) – Module path of data this Node depends on (ex. model.module1.module2.output)

  • proxy_value (Any) – Proxy value.

  • args (List[Any], optional) – Args. Defaults to None.

  • kwargs (Dict[str, Any], optional) – Kwargs. Defaults to None.

Returns:

The added InterventionProtocol proxy.

Return type:

Proxy

classmethod get_interventions(graph: Graph) Dict[source]#

Returns mapping from module_paths to InterventionNode names added to the given Graph.

Parameters:

graph (Graph) – Graph.

Returns:

Interventions.

Return type:

Dict

classmethod intervene(activations: Any, module_path: str, key: str, intervention_handler: InterventionHandler)[source]#

Entry point into the intervention graph. This should be hooked to all modules involved in the intervention graph.

Forms the current module_path key in the form <module path>.<output/input> and checks the graph's InterventionProtocol attachment attribute for this key. If it exists, the value is a list of node names to iterate through. Node args for intervention-type nodes should be [module_path, batch_size, batch_start, call_iter]. For each node, the counter for that intervention node is checked and updated; if the counter is not yet ready, it continues. Using batch_size and batch_start, torch.narrow is applied to tensors in activations to select only the batch-indexed tensors relevant to this intervention node, and the node's value is set using the indexed values. Because torch.narrow returns a view of the tensors rather than a copy, subsequent downstream nodes can edit the values in only the relevant tensors and have the edits update the original tensors. This both prevents interventions from affecting batches outside their purview and allows downstream intervention nodes in the graph to edit the output.

Parameters:
  • activations (Any) – Either the inputs or outputs of a torch module.

  • module_path (str) – Module path of the current relevant module relative to the root model.

  • key (str) – Key denoting either “input” or “output” of module.

  • intervention_handler (InterventionHandler) – Handler object that stores the intervention graph and keeps track of module call count.

Returns:

The activations, potentially modified by the intervention graph.

Return type:

Any
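The view semantics described above (torch.narrow returning a view rather than a copy, so in-place edits propagate back to the original tensor) can be demonstrated directly. This is general torch behavior, independent of nnsight:

```python
import torch

# A batch of 4 activations; imagine rows 1-2 belong to one intervention's
# batch group (batch_start=1, batch_size=2).
activations = torch.zeros(4, 3)

# torch.narrow returns a view over rows [1, 3), not a copy.
view = torch.narrow(activations, 0, 1, 2)

# An in-place edit on the view updates the original tensor...
view.fill_(1.0)

# ...while rows outside the batch group are untouched.
print(activations)
```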

classmethod style() Dict[str, Any][source]#

Visualization style for this protocol node.

Returns:

dictionary style.

Return type:

Dict

class nnsight.intervention.InterventionProxy(node: Node)[source]#

Subclass of Proxy that adds additional user functionality to proxies.

Examples

Saving a proxy so it is not deleted at the completion of its listeners is enabled with .save():

with model.trace('The Eiffel Tower is in the city of'):
    hidden_states = model.lm_head.input.save()
    logits = model.lm_head.output.save()

print(hidden_states)
print(logits)
property device: Collection[device]#

Property to retrieve the device of the traced proxy value or real value.

Returns:

Proxy value device or collection of devices.

Return type:

Union[torch.device, Collection[torch.device]]

property dtype: Collection[dtype]#

Property to retrieve the dtype of the traced proxy value or real value.

Returns:

Proxy value dtype or collection of dtypes.

Return type:

Union[torch.dtype, Collection[torch.dtype]]

property grad: InterventionProxy#

Accessing this property denotes that the user wishes to get the grad of the proxy tensor, so a Proxy of that request is created. The proxy is only generated the first time the property is referenced; afterwards the already-created one is returned.

Returns:

Grad proxy.

Return type:

Proxy
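The create-once-then-reuse behavior described above can be sketched with a generic lazily cached property (LazyGradHolder and its attributes are hypothetical names for illustration, not the nnsight implementation):

```python
class LazyGradHolder:
    """Sketch: build the grad proxy on first access, reuse it afterwards."""

    def __init__(self):
        self._grad = None
        self.builds = 0

    @property
    def grad(self):
        if self._grad is None:
            # First access: create the object representing the grad request.
            self.builds += 1
            self._grad = object()
        return self._grad


holder = LazyGradHolder()
first = holder.grad
second = holder.grad  # same object; no second build
```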

save() InterventionProxy[source]#

When called, indicates to the intervention graph not to delete the tensor values of the result.

Returns:

Proxy.

Return type:

InterventionProxy

property shape: Collection[Size]#

Property to retrieve the shape of the traced proxy value or real value.

Returns:

Proxy value shape or collection of shapes.

Return type:

Union[torch.Size,Collection[torch.Size]]

stop() InterventionProxy[source]#

When called, indicates to the intervention graph to stop execution of the model after this Proxy/Node is completed.

Returns:

Proxy.

Return type:

InterventionProxy

update(value: Node | Any) InterventionProxy[source]#

Updates the value of the Proxy via the creation of the UpdateProtocol node.

Parameters:

value (Union[Node, Any]) – New proxy value.

Returns:

Proxy.

Return type:

InterventionProxy