150 lines
6.6 KiB
Python
150 lines
6.6 KiB
Python
|
|
# -------------------------------------------------------------------------
|
||
|
|
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||
|
|
# Licensed under the MIT License. See License.txt in the project root for
|
||
|
|
# license information.
|
||
|
|
# --------------------------------------------------------------------------
|
||
|
|
|
||
|
|
import logging
|
||
|
|
|
||
|
|
import torch
|
||
|
|
|
||
|
|
# Module-level logger named after this module.
logger = logging.getLogger(name=__name__)
|
||
|
|
|
||
|
|
|
||
|
|
class PastKeyValuesHelper:
    """Helper functions to process past key values for encoder-decoder models.

    Two layouts of the key/value cache are converted between here:

    * grouped by layer: one 4-item entry per layer,
      ``(key_self_i, value_self_i, key_cross_i, value_cross_i)``;
    * grouped by self/cross attention: a flat sequence holding every layer's
      self-attention key/value pair first, followed by every layer's
      cross-attention key/value pair.
    """

    @staticmethod
    def get_past_names(num_layers, present: bool = False):
        """Return the flat list of cache names, grouped by self/cross attention.

        All ``*_self_{i}`` names for layers ``0 .. num_layers - 1`` come first,
        followed by all ``*_cross_{i}`` names.

        Args:
            num_layers: Number of layers.
            present: Use the ``present_`` prefix instead of ``past_``.

        Returns:
            List of ``4 * num_layers`` names.
        """
        prefix = "present" if present else "past"
        past_self_names = []
        past_cross_names = []
        for i in range(num_layers):
            past_self_names.extend([f"{prefix}_key_self_{i}", f"{prefix}_value_self_{i}"])
            past_cross_names.extend([f"{prefix}_key_cross_{i}", f"{prefix}_value_cross_{i}"])
        return past_self_names + past_cross_names

    @staticmethod
    def group_by_self_or_cross(present_key_values):
        """Split present state from grouped by layer to grouped by self/cross attention.

        Before: (key_self_0, value_self_0, key_cross_0, value_cross_0),
                (key_self_1, value_self_1, key_cross_1, value_cross_1), ...
        After: [key_self_0, value_self_0, key_self_1, value_self_1, ...],
               [key_cross_0, value_cross_0, key_cross_1, value_cross_1, ...]

        This is an alias of ``group_by_self_and_cross(present_key_values)``.

        Returns:
            (present_self, present_cross) as two lists.

        Raises:
            AssertionError: If any per-layer entry does not have exactly four items.
        """
        return PastKeyValuesHelper.group_by_self_and_cross(present_key_values)

    @staticmethod
    def group_by_layer(past, num_layers):
        """Reorder past state from grouped by self/cross attention to grouped by layer.

        Before: past_key_self_0, past_value_self_0, past_key_self_1, past_value_self_1, ...,
                past_key_cross_0, past_value_cross_0, past_key_cross_1, past_value_cross_1, ...
        After: [past_key_self_0, past_value_self_0, past_key_cross_0, past_value_cross_0],
               [past_key_self_1, past_value_self_1, past_key_cross_1, past_value_cross_1], ...

        Args:
            past: Flat sequence of ``4 * num_layers`` entries, self attention first.
            num_layers: Number of layers.

        Returns:
            Tuple with one 4-item list per layer.

        Raises:
            AssertionError: If ``len(past) != 4 * num_layers``.
        """
        assert len(past) == 4 * num_layers, f"Expected {4 * num_layers} items. Got {len(past)}"
        # Cross-attention entries start after all 2 * num_layers self-attention entries.
        cross_offset = 2 * num_layers
        return tuple(
            [
                past[2 * i],
                past[2 * i + 1],
                past[cross_offset + 2 * i],
                past[cross_offset + 2 * i + 1],
            ]
            for i in range(num_layers)
        )

    @staticmethod
    def back_group_by_layer(past_key_values: tuple[torch.Tensor, ...]):
        """Reorder past state from grouped by self/cross attention to grouped by layer.

        Before: past_key_self_0, past_value_self_0, past_key_self_1, past_value_self_1, ...,
                past_key_cross_0, past_value_cross_0, past_key_cross_1, past_value_cross_1, ...
        After: (past_key_self_0, past_value_self_0, past_key_cross_0, past_value_cross_0),
               (past_key_self_1, past_value_self_1, past_key_cross_1, past_value_cross_1), ...

        Args:
            past_key_values: Flat sequence grouped by self and cross attention.
                Its length must be a multiple of four.

        Returns:
            past_tuples: Tuple with one 4-item tuple per layer.

        Raises:
            AssertionError: If the length is not a multiple of four.
        """
        num_items = len(past_key_values)
        # Without this check, a length such as 6 would silently drop entries and
        # pair self-attention items with the wrong cross-attention items.
        assert num_items % 4 == 0, f"Expected a multiple of four items. Got {num_items}"
        half_idx = num_items // 2  # start of the cross-attention half
        return tuple(
            (
                past_key_values[idx],
                past_key_values[idx + 1],
                past_key_values[half_idx + idx],
                past_key_values[half_idx + idx + 1],
            )
            for idx in range(0, half_idx, 2)
        )

    @staticmethod
    def group_by_self_and_cross(
        present_key_values: tuple[tuple[torch.Tensor, ...], ...], concat: bool = False
    ):
        """Split present state from grouped by layer to grouped by self/cross attention.

        Before: (past_key_self_0, past_value_self_0, past_key_cross_0, past_value_cross_0),
                (past_key_self_1, past_value_self_1, past_key_cross_1, past_value_cross_1), ...
        After: [past_key_self_0, past_value_self_0, past_key_self_1, past_value_self_1, ...],
               [past_key_cross_0, past_value_cross_0, past_key_cross_1, past_value_cross_1, ...]

        Args:
            present_key_values: Sequence with one 4-item entry per layer.
            concat: If true, return a single list with the self-attention entries
                followed by the cross-attention entries.

        Returns:
            If ``concat`` is false, a tuple ``(present_self, present_cross)`` of two lists.
            If ``concat`` is true, the single concatenated list.

        Raises:
            AssertionError: If any per-layer entry does not have exactly four items.
        """
        present_self: list[torch.Tensor] = []
        present_cross: list[torch.Tensor] = []
        for present_layer_i in present_key_values:
            assert len(present_layer_i) == 4, f"Expected to have four items. Got {len(present_layer_i)}"
            present_key_self, present_value_self, present_key_cross, present_value_cross = present_layer_i
            present_self.extend([present_key_self, present_value_self])
            present_cross.extend([present_key_cross, present_value_cross])
        if concat:
            return present_self + present_cross
        return present_self, present_cross

    @staticmethod
    def get_input_names(past_key_values: tuple, encoder=True):
        """Process input names of the model wrapper.

        Args:
            past_key_values: When ``encoder`` is true, a flat sequence of
                ``4 * num_layers`` entries. Otherwise, a sequence with one entry per layer.
            encoder: Controls how the layer count is derived from ``past_key_values``.
                It also selects the name prefix: ``present_`` when true, ``past_`` otherwise.

        Returns:
            names (List[str]): All self-attention key/value names first, then all
            cross-attention key/value names.
        """
        num_layers = len(past_key_values) // 4 if encoder else len(past_key_values)
        # Same names and ordering that get_past_names produces.
        return PastKeyValuesHelper.get_past_names(num_layers, present=encoder)
|