nnsight.intervention#
The intervention module extends the tracing module to add PyTorch-specific interventions to a given computation graph. It defines its own protocols, contexts, backends, and graph primitives to achieve this.
- class nnsight.intervention.protocols.entrypoint.EntryPoint[source]#
An EntryPoint protocol should have its value set manually, outside of normal graph execution. This makes these Nodes special, and they are handled differently in a variety of cases. Subclassing EntryPoint informs those cases to handle the Node differently. Examples are InterventionProtocol and GradProtocol.
- class nnsight.intervention.protocols.grad.GradProtocol[source]#
Protocol which adds a backwards hook via .register_hook() to a Tensor. The hook injects the gradients into the node’s value when it executes. Nodes created via this protocol are relative to the next time .backward() is called during tracing, allowing separate .grads to reference separate backwards passes:
Uses an attachment to store the number of times .backward() has been called during tracing, so a given .grad hook only has its value injected at the appropriate backwards pass.
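A minimal sketch of how this surfaces in user code, assuming a LanguageModel wrapper around GPT-2 (the module path and prompt are illustrative); each .grad resolves against the next .backward() call in the trace:

with model.trace("Hello"):
    hidden = model.transformer.h[0].output[0].save()
    grad = model.transformer.h[0].output[0].grad.save()  # hooked via GradProtocol
    loss = model.output.logits.sum()
    loss.backward()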
- class nnsight.intervention.protocols.intervention.InterventionProtocol[source]#
- classmethod intervene(activations: Any, module_path: str, module: Module, key: str, interleaver: Interleaver)[source]#
Entry to intervention graph. This should be hooked to all modules involved in the intervention graph.
Forms the current module_path key in the form of <module path>.<output/input>. Checks the graph’s InterventionProtocol attachment attribute for this key. If it exists, the value is a list of (start: int, end: int) subgraphs to iterate through. Node args for intervention-type nodes should be
[module_path, (batch_start, batch_size), iteration]
. Checks and updates the counter (the number of times this module has been called for this Node) for the given intervention node. If the count is not yet ready compared to the iteration, continues. Using batch_size and batch_start, applies torch.narrow to tensors in activations to select only the batch-indexed tensors relevant to this intervention node. Sets the value of the node using the indexed values. torch.narrow returns a view of the tensors as opposed to a copy, allowing subsequent downstream nodes to edit the values in only the relevant tensors and have the edits update the original tensors. This both prevents interventions from affecting batches outside their purview and allows edits to the output from downstream intervention nodes in the graph.
- Parameters:
activations (Any) – Either the inputs or outputs of a torch module.
module_path (str) – Module path of the current relevant module relative to the root model.
key (str) – Key denoting either “input” or “output” of module.
interleaver (Interleaver) – Interleaver object that stores the intervention graph and keeps track of module call counts.
- Returns:
The activations, potentially modified by the intervention graph.
- Return type:
Any
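The torch.narrow view semantics described above can be demonstrated standalone; the tensor shape and batch group values here are illustrative:

import torch

activations = torch.randn(6, 10)                       # batch dimension first
batch_start, batch_size = 2, 2                         # hypothetical batch group
view = activations.narrow(0, batch_start, batch_size)  # a view, not a copy
view[:] = 0.0                                          # in-place edits propagate
assert activations[2:4].abs().sum() == 0               # only this group affected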
- class nnsight.intervention.protocols.module.ApplyModuleProtocol[source]#
Protocol that references some root model and calls its .forward() method given some input. Using .forward() vs .__call__() means it won’t trigger hooks. Uses an attachment to the Graph to store the model.
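The hooks distinction is standard PyTorch behavior and can be shown standalone (the module and hook here are illustrative):

import torch

lin = torch.nn.Linear(2, 2)
lin.register_forward_hook(lambda module, args, output: print("hook fired"))

x = torch.ones(1, 2)
_ = lin(x)          # __call__ runs registered hooks: prints "hook fired"
_ = lin.forward(x)  # calling .forward() directly bypasses hooks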
- classmethod add(graph: InterventionGraph, module_path: str, *args, hook=False, **kwargs) Self [source]#
Creates and adds an ApplyModuleProtocol to the Graph. Assumes the attachment has already been added via ApplyModuleProtocol.set_module().
- Parameters:
graph (Graph) – Graph to add the Protocol to.
module_path (str) – Module path (model.module1.module2 etc), of module to apply from the root module.
- Returns:
ApplyModule Proxy.
- Return type:
Self
- classmethod execute(node: InterventionNode) None [source]#
Executes the ApplyModuleProtocol on Node.
- Parameters:
node (Node) – ApplyModule Node.
- class nnsight.intervention.contexts.editing.EditingTracer(model: NNsight, *args, inplace: bool = False, **kwargs)[source]#
The EditingTracer exists because we want to return the edited model from __enter__, not the Tracer itself. While we’re here, we might as well force the backend to be EditingBackend.
Global patching allows us to add un-traceable operations to nnsight by replacing them with ones that use the GLOBAL_TRACING_CONTEXT to add the operation to the current graph.
- class nnsight.intervention.contexts.interleaving.InterleavingTracer(model: NNsight, method: str | None = None, backend: Backend | None = None, parent: GraphType | None = None, validate: bool = False, debug: bool | None = None, **kwargs)[source]#
This is the Tracer type that actually interleaves an InterventionGraph with a PyTorch model upon execute.
- args#
Positional arguments. The first is which method to interleave with; subsequent args are invoker inputs.
- Type:
Tuple[…]
- kwargs#
Keyword arguments passed to the method to interleave. These are “global” keyword arguments for our chosen method, while kwargs for a given invoker are used for preprocessing that invoker’s input.
- Type:
Dict[str,Any]
- class nnsight.intervention.contexts.invoker.Invoker(tracer: InterleavingTracer, *args, scan: bool = False, **kwargs)[source]#
An Invoker is meant to work in tandem with a nnsight.intervention.contexts.InterleavingTracer to enter input and manage intervention tracing.
- tracer#
Tracer object to enter input and manage context.
- Type:
nnsight.contexts.Tracer.Tracer
- inputs#
Initially entered inputs, then post-processed inputs from model’s ._prepare_inputs(…) method.
- Type:
tuple[Any]
- scan#
Whether to execute the model using FakeTensors in order to update the potential sizes/dtypes of all modules’ Envoys’ inputs/outputs, as well as to validate that things work correctly. Scanning is not free computation-wise, so you may want to leave this False when running in a loop. When making interventions, you may get shape errors if scan is False, as operations are validated based on shapes; for looped calls where shapes are consistent, you may want scan=True for just the first iteration. Defaults to False.
- Type:
bool
- kwargs#
Keyword arguments passed to the model’s _prepare_inputs method.
- Type:
Dict[str,Any]
- scanning#
If currently scanning.
- Type:
bool
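A sketch of the loop pattern suggested above, scanning only on the first iteration (the model and prompts are illustrative):

outputs = []
with model.trace() as tracer:
    for i, prompt in enumerate(prompts):
        # Scan only once; later iterations share the same shapes.
        with tracer.invoke(prompt, scan=(i == 0)):
            outputs.append(model.output.save())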
- class nnsight.intervention.contexts.local.LocalContext(*args, backend: ~nnsight.tracing.backends.base.Backend | None = None, parent: ~nnsight.tracing.graph.graph.GraphType | None = None, graph: ~nnsight.tracing.graph.graph.GraphType | None = None, graph_class: ~typing.Type[~nnsight.tracing.graph.graph.SubGraph] = <class 'nnsight.tracing.graph.graph.SubGraph'>, node_class: ~typing.Type[~nnsight.tracing.graph.node.NodeType] = <class 'nnsight.tracing.graph.node.Node'>, proxy_class: ~typing.Type[~nnsight.tracing.graph.proxy.ProxyType] = <class 'nnsight.tracing.graph.proxy.Proxy'>, debug: bool = False, **kwargs)[source]#
- class nnsight.intervention.contexts.local.RemoteContext(*args, backend: ~nnsight.tracing.backends.base.Backend | None = None, parent: ~nnsight.tracing.graph.graph.GraphType | None = None, graph: ~nnsight.tracing.graph.graph.GraphType | None = None, graph_class: ~typing.Type[~nnsight.tracing.graph.graph.SubGraph] = <class 'nnsight.tracing.graph.graph.SubGraph'>, node_class: ~typing.Type[~nnsight.tracing.graph.node.NodeType] = <class 'nnsight.tracing.graph.node.Node'>, proxy_class: ~typing.Type[~nnsight.tracing.graph.proxy.ProxyType] = <class 'nnsight.tracing.graph.proxy.Proxy'>, debug: bool = False, **kwargs)[source]#
- class nnsight.intervention.contexts.session.Session(model: NNsight, validate: bool = False, debug: bool | None = None, **kwargs)[source]#
A Session simply allows grouping multiple Tracers in one computation graph.
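A minimal sketch of grouping two traces under one session (model and prompts illustrative):

with model.session():
    with model.trace("The first prompt"):
        a = model.output.save()
    with model.trace("The second prompt"):
        b = model.output.save()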
- class nnsight.intervention.contexts.tracer.InterventionTracer(*args, **kwargs)[source]#
Extension of base Tracer to add additional intervention functionality and type hinting for intervention proxies.
- class nnsight.intervention.backends.editing.EditingBackend(model: NNsight)[source]#
Backend to set the default graph to the current InterventionGraph. Assumes the final Node is an InterleavingTracer.
- class nnsight.intervention.backends.remote.RemoteBackend(model_key: str, host: str | None = None, blocking: bool = True, job_id: str | None = None, ssl: bool | None = None, api_key: str = '')[source]#
Backend to execute a context object via a remote service.
Context object must inherit from RemoteMixin and implement its methods.
- url#
Remote host url. Defaults to that set in CONFIG.API.HOST.
- Type:
str
- blocking_request(graph: Graph) Dict[int, Any] | None [source]#
Send intervention request to the remote service while waiting for updates via websocket.
- Parameters:
graph (Graph) – Graph from which the intervention request is created.
- get_response() Dict[int, Any] | None [source]#
Retrieves and handles the response object from the remote endpoint.
- Raises:
Exception – If there was a status code other than 200 for the response.
- Returns:
Response.
- Return type:
ResponseModel
- handle_response(response: ResponseModel, graph: Graph | None = None) Dict[int, Any] | None [source]#
Handles incoming response data.
Logs the response object. If the job is completed, retrieves and streams the result from the remote endpoint. Uses torch.load to decode and load the ResultModel into memory. Uses the backend object’s .handle_result method to handle the decoded result.
- Parameters:
response (Any) – JSON data to convert to a ResponseModel.
- Raises:
Exception – If the job’s status is ResponseModel.JobStatus.ERROR
- Returns:
ResponseModel.
- Return type:
ResponseModel
- non_blocking_request(graph: Graph)[source]#
Send an intervention request to the remote service if a request is provided; otherwise get the status of an existing job.
Sets CONFIG.API.JOB_ID on the initial request so the status of said job can be fetched later.
When the job is completed, clears CONFIG.API.JOB_ID so a new job can be requested.
- Parameters:
graph (Graph) – Graph to build the request from, if submitting a new request.
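In user code the remote backend is typically selected via the trace call rather than constructed directly; a hedged sketch, assuming a configured host and API key (prompt illustrative):

with model.trace("Hello", remote=True):
    out = model.output.save()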
- class nnsight.intervention.graph.graph.InterventionGraph(*args, model: NNsight | None = None, **kwargs)[source]#
The InterventionGraph is the special SubGraph type that handles the complex intervention operations a user wants to make during interleaving. We need to .compile() it before execution to determine how to execute interventions appropriately.
- interventions#
- grad_subgraph#
- compiled#
- call_counter#
- deferred#
- clean(start: int | None = None)[source]#
Cleans up dependencies of Nodes so their values are appropriately memory-managed. Cleans all Nodes from start to end, regardless of whether they are on this Graph.
- Parameters:
start (Optional[int], optional) – Node index to start cleaning up from. Defaults to None.
- cleanup() None [source]#
Because some modules may be executed more than once, and to accommodate memory management (just like a loop), intervention graph sections defer updating the remaining listeners of their Nodes if this is not the last time the section will be executed. If we never learn that a given execution was the last one, there may still be deferred sections after execution. These are left over in graph.deferred, and we therefore need to update their dependencies.
- copy(new_graph: Self = None, parent: GraphType | None = None, memo: Dict[int, NodeType] | None = None) Self [source]#
Creates a shallow copy of this SubGraph.
- Parameters:
- Returns:
New graph.
- Return type:
Self
- count(index: int, iteration: int | List[int] | slice) bool [source]#
Increments the count of times a given intervention Node has tried to execute and returns whether the Node is ready and whether it needs to be deferred.
- Parameters:
index (int) – Index of intervention node to return count for.
iteration (Union[int, List[int], slice]) – What iteration(s) this Node should be executed for.
- Returns:
bool: Whether this Node should be executed on this iteration. bool: Whether this Node and its recursive listeners should have the updating of their remaining listeners (and therefore their destruction) deferred.
- Return type:
bool
- class nnsight.intervention.graph.node.InterventionNode(*args, fake_value: Any | None, **kwargs)[source]#
This is the intervention extension of the base Node type.
It has a fake_value to see information about this Node’s future value before execution. It adds additional functionality to Node.prepare_inputs to handle Tensors.
- classmethod prepare_inputs(inputs: Any, device: device | None = None, fake: bool = False) Any [source]#
Override prepare_inputs to make sure tensor inputs are handled correctly (moved to the given device, and optionally swapped for their fake values).
- Parameters:
inputs (Any) – Inputs to prepare.
device (Optional[torch.device], optional) – Device to move tensor inputs to. Defaults to None.
fake (bool, optional) – Whether to prepare fake values instead of real ones. Defaults to False.
- Returns:
The prepared inputs.
- Return type:
Any
- class nnsight.intervention.graph.node.ValidatingInterventionNode(*args, **kwargs)[source]#
The ValidatingInterventionNode executes its target using the fake_values of all of its dependencies to calculate a new fake_value for this node. Does not do this if the Node is detached from any graph, already has a fake_value (specified by whoever created the Node), or is a Protocol.
- class nnsight.intervention.graph.proxy.InterventionProxy(node: InterventionNode)[source]#
- property device: Collection[device]#
Property to retrieve the device of the traced proxy value or real value.
- Returns:
Proxy value device or collection of devices.
- Return type:
Union[torch.device, Collection[torch.device]]
- property dtype: Collection[dtype]#
Property to retrieve the dtype of the traced proxy value or real value.
- Returns:
Proxy value dtype or collection of dtypes.
- Return type:
Union[torch.dtype, Collection[torch.dtype]]
- property grad: Self#
Calling this denotes the user wishes to get the grad of the proxy tensor, so we create a Proxy of that request. Only generates a proxy the first time it is referenced; otherwise returns the already-set one.
- Returns:
Grad proxy.
- Return type:
Self
- property shape: Collection[Size]#
Property to retrieve the shape of the traced proxy value or real value.
- Returns:
Proxy value shape or collection of shapes.
- Return type:
Union[torch.Size,Collection[torch.Size]]
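A sketch of reading these properties during tracing; under .scan() they reflect fake values, so no real computation is required (model and module names are illustrative):

with model.scan(" "):
    print(model.layer2.output.shape)    # fake shape
    print(model.layer2.output.dtype)    # fake dtype
    print(model.layer2.output.device)   # fake device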
- class nnsight.intervention.base.NNsight(model: Module, rename: Dict[str, str] | None = None)[source]#
Main class to be implemented as a wrapper for PyTorch models wishing to gain this package’s functionality.
Class Attributes:
proxy_class (Type[InterventionProxy]): InterventionProxy-like type to use as a Proxy for this Model’s inputs and outputs. Can have Model-specific functionality added via a new subclass.
__methods__ (Dict[str,str]): Mapping of method name (which will open up a .trace context) to the actual method name to execute / interleave with.
For example, let’s say I had a method on my underlying ._model called .run that I wanted to have the NNsight interleaving functionality applied to. I could define a method on my NNsight subclass called ._run, which might look like:

def _run(self, *inputs, **kwargs):
    inputs, kwargs = some_preprocessing(inputs, kwargs)
    return self._model.run(*inputs, **kwargs)

I could then have my __methods__ attribute look like __methods__ = {'run': '_run'}. This would allow me to do:

with model.run(...):
    output = model.output.save()
- _model#
Underlying torch module.
- Type:
torch.nn.Module
- _default_graph#
Intervention graph to start from when calling NNsight.trace. This is set via the editing context NNsight.edit.
- Type:
InterventionGraph
- edit(*inputs: Any, inplace: bool = False, **kwargs: Dict[str, Any]) InterleavingTracer | Any [source]#
Create a trace context with an edit backend and apply a list of edits.
The edit backend sets a default graph on an NNsight model copy which is run on future trace calls.
This operation is not inplace!
- Parameters:
inplace (bool) – If True, makes edits in-place.
- Returns:
Either the Tracer used for tracing, or the raw output if trace is False.
- Return type:
Union[Tracer, Any]
Example
from nnsight import LanguageModel

gpt2 = LanguageModel("openai-community/gpt2")

class ComplexModule(torch.nn.Module):

    def __init__(self):
        super().__init__()
        self.one = WrapperModule()

    def forward(self, x):
        return self.one(x)

l0 = gpt2.transformer.h[0]
l0.attachment = ComplexModule()

with gpt2.edit("test") as gpt2_edited:
    acts = l0.output[0]
    l0.output[0][:] = l0.attachment(acts, hook=True)

with gpt2.trace(MSG_prompt):
    original = l0.output[0].clone().save()
    l0.output[0][:] *= 0.0
    original_output = gpt2.output.logits.save()

with gpt2_edited.trace(MSG_prompt):
    one = l0.attachment.one.output.clone().save()
    l0.attachment.output *= 0.0
    edited_output = gpt2.output.logits.save()

print(original_output)
print(edited_output)
- get(path: str) Envoy | InterventionProxyType [source]#
Gets the Envoy/Proxy via its path.
- e.g.:
model = nnsight.LanguageModel("openai-community/gpt2")
module = model.get('transformer.h.0.mlp')

with model.trace("Hello"):
    value = model.get('transformer.h.0.mlp.output').save()
- Parameters:
path (str) – ‘.’ separated path.
- Returns:
Fetched Envoy/Proxy
- Return type:
Union[Envoy, InterventionProxyType]
- interleave(interleaver: Interleaver, *args, fn: Callable | str | None = None, **kwargs) Any [source]#
This is the point in nnsight where we finally execute the model and interleave our custom logic. Simply resolves the function and executes it given some input within the Interleaver context. This method lives here rather than on the Interleaver because some models might want to define custom interleaving behavior, for example loading real model weights before execution.
- Parameters:
interleaver (Interleaver) – Interleaver.
fn (Optional[Union[Callable, str]], optional) – Function to interleave with. Defaults to None and therefore NNsight._execute.
- Returns:
The output of the executed function.
- Return type:
Any
- proxy_class#
alias of
InterventionProxy
- scan(*inputs, **kwargs) InterleavingTracer [source]#
Context just to populate fake tensor proxy values using scan and validate. Useful when looking for just the shapes of future tensors.
Examples
with model.scan(" "):
    dim = model.module.output.shape[-1]

print(dim)
- Returns:
Tracer context with Noop backend.
- Return type:
InterleavingTracer
- session(backend: Backend | str | None = None, **kwargs) Session [source]#
Create a session context using a Session.
- Parameters:
backend (Backend) – Backend for this Session object.
- Returns:
Session.
- Return type:
Session
- to(*args, **kwargs) Self [source]#
Override torch.nn.Module.to so this returns the NNsight model, not the underlying module, when doing: model = model.to(…)
- Returns:
Envoy.
- Return type:
Self
- trace(*inputs: Any, trace: bool = True, scan: bool = False, method: str | None = None, invoker_kwargs: Dict[str, Any] | None = None, backend: Backend | str | None = None, **kwargs: Dict[str, Any]) InterleavingTracer | Any [source]#
Entrypoint into the tracing and interleaving functionality nnsight provides.
In short, allows access to the future inputs and outputs of modules in order to trace what operations you would like to perform on them. This can be as simple as accessing and saving activations for inspection, or as complicated as transforming the activations and gradients in a forward pass over multiple inputs.
- Parameters:
inputs (tuple[Any]) – When positional arguments are provided directly to .trace, we assume there is only one Invoker and therefore immediately create and enter an Invoker.
trace (bool, optional) – Whether to open a tracing context. Otherwise, immediately run the model and return the raw output. Defaults to True.
scan (bool) – Exposed invoker kwarg to scan for the provided input. No effect if there is no input.
method (Optional[str]) – String name of method to interleave with. Defaults to None and therefore NNsight._execute
invoker_kwargs (Dict[str, Any], optional) – Keyword arguments to pass to Invoker initialization, and then downstream to the model’s .prepare_inputs(…) method. Used when giving input directly to .trace(…). Defaults to None.
kwargs (Dict[str, Any]) – Keyword arguments passed to Tracer initialization, and then downstream to the model’s execution method.
- Raises:
ValueError – If trace is False and no inputs were provided (nothing to run with)
- Returns:
Either the Tracer used for tracing, or the raw output if trace is False.
- Return type:
Union[Tracer, Any]
Examples
There are a few ways you can use
.trace(...)
depending on your use case. Let’s use this extremely basic model for our examples:
import torch
from collections import OrderedDict

input_size = 5
hidden_dims = 10
output_size = 2

model = torch.nn.Sequential(OrderedDict([
    ('layer1', torch.nn.Linear(input_size, hidden_dims)),
    ('sigma1', torch.nn.Sigmoid()),
    ('layer2', torch.nn.Linear(hidden_dims, output_size)),
    ('sigma2', torch.nn.Sigmoid()),
]))

example_input = torch.rand((1, input_size))
The first example has us running the model with a single example input, and saving the input and output of ‘layer2’ as well as the final output using the tracing context.
from nnsight import NNsight

with NNsight(model).trace(example_input) as model:
    l2_input = model.layer2.input.save()
    l2_output = model.layer2.output.save()
    output = model.output.save()

print(l2_input)
print(l2_output)
print(output)
The second example allows us to divide multiple inputs into one batch, and to scope an inner invoker context to each one. We indicate this simply by not passing any positional inputs into .trace(…). The Tracer object then expects you to enter each input via Tracer.invoke(…).
example_input2 = torch.rand((1, input_size))

with NNsight(model).trace() as model:
    with model.invoke(example_input):
        output1 = model.output.save()
    with model.invoke(example_input2):
        output2 = model.output.save()

print(output1)
print(output2)
- class nnsight.intervention.envoy.Envoy(module: Module, module_path: str = '', alias_path: str | None = None, rename: Dict[str, str] | None = None)[source]#
Envoy objects act as proxies for torch modules themselves within a model’s module tree in order to add nnsight functionality. Proxies of the underlying module’s output and input are accessed by .output and .input respectively.
- path#
String representing the attribute path of this Envoy’s module relative to the root model. Separated by ‘.’, e.g. (‘.transformer.h.0.mlp’).
- Type:
str
- output#
Proxy object representing the output of this Envoy’s module. Reset on forward pass.
- Type:
nnsight.intervention.InterventionProxy
- inputs#
Proxy object representing the inputs of this Envoy’s module. Proxy is in the form of (Tuple[Tuple[<Positional arg>], Dict[str, <Keyword arg>]]). Reset on forward pass.
- Type:
nnsight.intervention.InterventionProxy
- input#
Alias for the first positional Proxy input, i.e. Envoy.inputs[0][0].
- Type:
nnsight.intervention.InterventionProxy
- iter#
Iterator object allowing selection of specific .input and .output iterations of this Envoy.
- Type:
nnsight.envoy.EnvoyIterator
- _module#
Underlying torch module.
- Type:
torch.nn.Module
- _fake_outputs#
List of ‘meta’ tensors built from the outputs of the most recent _scan. It is a list as there can be multiple shapes for a module called more than once.
- Type:
List[torch.Tensor]
- _fake_inputs#
List of ‘meta’ tensors built from the inputs of the most recent _scan. It is a list as there can be multiple shapes for a module called more than once.
- Type:
List[torch.Tensor]
- _rename#
Optional mapping of (old name -> new name). For example, to rename all gpt ‘attn’ modules to ‘attention’ you would pass: rename={r”attn”: “attention”}. Note this does not actually change the underlying module names, just how you access their Envoys. Renaming will replace Envoy.path, but Envoy._path represents the pre-renamed true attribute path.
- Type:
Optional[Dict[str,str]]
- _tracer#
Object which adds this Envoy’s module’s output and input proxies to an intervention graph. Must be set on Envoy objects manually by the Tracer.
- Type:
nnsight.context.Tracer.Tracer
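A hedged sketch of the _rename mapping in use (model choice illustrative); only the Envoy access path changes, not the underlying module names:

from nnsight import LanguageModel

model = LanguageModel("openai-community/gpt2", rename={r"attn": "attention"})
attention = model.transformer.h[0].attention  # resolves to the module named 'attn'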
- all(propagate: bool = True) Envoy [source]#
By default, this module’s inputs and outputs only refer to the first time it is called. Use .all() to have .input and .output refer to all iterations.
- Returns:
Self.
- Return type:
Envoy
- property input: InterventionProxyType#
Getting the first positional argument input of the model’s module.
- Returns:
Input proxy.
- Return type:
InterventionProxyType
- property inputs: InterventionProxyType#
Calling this denotes the user wishes to get the input of the underlying module, so we create a Proxy of that request. Only generates a proxy the first time it is referenced; otherwise returns the already-set one.
- Returns:
Input proxy.
- Return type:
InterventionProxyType
- modules(include_fn: Callable[[Envoy], bool] | None = None, names: bool = False, envoys: List | None = None) List[Envoy] [source]#
Returns all Envoys in the Envoy tree.
- Parameters:
include_fn (Callable, optional) – Optional function to be run against all Envoys to check if they should be included in the final collection of Envoys. Defaults to None.
names (bool, optional) – If to include the name/module_path of returned Envoys along with the Envoy itself. Defaults to False.
- Returns:
Included Envoys
- Return type:
List[Envoy]
- named_modules(*args, **kwargs) List[Tuple[str, Envoy]] [source]#
Returns all Envoys in the Envoy tree along with their name/module_path.
- Parameters:
include_fn (Callable, optional) – Optional function to be run against all Envoys to check if they should be included in the final collection of Envoys. Defaults to None.
- Returns:
Included Envoys and their names/module_paths.
- Return type:
List[Tuple[str, Envoy]]
- next(increment: int = 1) Envoy [source]#
By default, this module’s inputs and outputs only refer to the first time it is called. Use .next() to select which iteration .input and .output refer to.
- Parameters:
increment (int, optional) – How many iterations to jump. Defaults to 1.
- Returns:
Self.
- Return type:
Envoy
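A sketch of .next() during multi-token generation, assuming a LanguageModel wrapper (module path and generation kwargs are illustrative):

with model.generate("Hello", max_new_tokens=2):
    first = model.lm_head.output.save()          # first forward pass
    second = model.lm_head.next().output.save()  # second forward pass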
- property output: InterventionProxyType#
Calling this denotes the user wishes to get the output of the underlying module, so we create a Proxy of that request. Only generates a proxy the first time it is referenced; otherwise returns the already-set one.
- Returns:
Output proxy.
- Return type:
InterventionProxyType
- class nnsight.intervention.envoy.IterationEnvoy(envoy: Envoy)[source]#
- property input: InterventionProxyType#
Getting the first positional argument input of the model’s module.
- Returns:
Input proxy.
- Return type:
InterventionProxyType
- property output: InterventionProxyType#
Calling this denotes the user wishes to get the output of the underlying module, so we create a Proxy of that request. Only generates a proxy the first time it is referenced; otherwise returns the already-set one.
- Returns:
Output proxy.
- Return type:
InterventionProxyType
- class nnsight.intervention.interleaver.Interleaver(graph: InterventionGraph, batch_groups: List[Tuple[int, int]] | None = None, input_hook: Callable | None = None, output_hook: Callable | None = None, batch_size: int | None = None)[source]#
The Interleaver is responsible for executing a function involving a PyTorch model alongside a user’s custom functionality (represented by an InterventionGraph). This is called interleaving.
The InterventionGraph has information about which components (modules) of the model the user’s custom logic will interact with. As the Interleaver is a context, entering it adds the appropriate hooks to these components which act as a bridge between the model’s original computation graph and the InterventionGraph. Exiting the Interleaver removes these hooks.
- graph#
The computation graph representing the user’s custom intervention logic.
- Type:
InterventionGraph
- batch_groups#
A batch group is a section of tensor values related to a given intervention. Each is a tuple of (batch_start, batch_length). So if batch group 0 is (0, 4), it starts at index 0 and goes until index 3. The batch index is assumed to be the first dimension of all Tensors. InterventionProtocol Nodes know which batch group they are a part of via their arguments; that value is the index into batch_groups.
- Type:
Optional[List[Tuple[int, int]]]
- input_hook#
- Type:
Optional[Callable]
- output_hook#
- Type:
Optional[Callable]
- batch_size#
Total batch size, used to decide whether a tensor is batched: if a Tensor’s first dimension isn’t batch_size, we don’t need to narrow it for its batch_group. Defaults to None, and therefore to the sum of the last batch_group (batch_start + batch_length).
- Type:
Optional[int]
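A hedged sketch of how invokers map onto batch groups, reusing the toy model from the NNsight.trace examples; each invoker’s interventions only see and edit their own rows of the batched activations:

with NNsight(model).trace() as model:
    with model.invoke(example_input):    # batch group 0: (0, 1)
        model.layer2.output[:] = 0.0     # zeroes only this invoker's rows
    with model.invoke(example_input2):   # batch group 1: (1, 1)
        out = model.output.save()        # unaffected by the edit above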