开局一道面试题:不量化,不损失精度,如何用一张16GB的显卡推断 fp16的70B大模型?
之前在包大人在文章大模型Kaggle比赛首秀冠军方案总结提到一个黑科技,在kaggle平台上用16GB的T4显卡+32GB内存推断70B的大模型,今天把这个技术展开讲解。

我们传统的模型加载到显卡的流程是:
1.创建模型 
2.在内存中加载其权重(通常在一个叫做state_dict的对象中) 
3.在创建的模型中加载这些权重 
4.将模型移动到设备上进行推理
在这个常规步骤中,第二步需要把模型加载到RAM里,用fp16的话的显存消耗量大约是模型参数量的两倍的数量级,比如70B模型,大概需要140GB显存存放模型。要是使用32位全精度的话,需要4倍的数据数量级参数, 计算原理(1 个float32耗费4bytes,1个float16耗费 2bytes),关于参数和显存的推演,更详细的细节看之前的文章:大模型训练为什么用A100不用4090
并且在第三步,模型从RAM挪到显存,还需要一份额外的拷贝,耗费同样的大小的内存。
当你想加载更大的模型,比如BLOOM或OPT-176B(1760亿个参数)的话,按上面推演,你将需要1.4TB的CPU RAM。还是十分夸张的!而所有这些只是为了在第4步将模型移动到GPU上。
所以正常来讲,70B的大模型,至少需要4张40GB的A100来推断,一张16GB的显卡肯定是不够的,另外32GB的RAM也是远远不够的,别说显存,内存都放不下一个模型的参数。
那么我们有什么办法呢?
既然显卡放不下一个模型,那我们能否放一部分到显卡上,边推理边释放呢?
答案是肯定的,并且这个方案在 Kaggle比赛里已经实践过,在一些特定的限制资料的场景,可以用来当一个极其节省现存的推断方案。
具体的流程可以讲解为:
1.创建一个空的(例如,没有权重的)模型
2.决定每一层将要去哪里(当有多个设备可用时)
3.在内存中加载其权重的一部分
4.在空模型中加载这些权重
5.将权重移动到设备上进行推理 
6.从第3步重复,直到所有的权重都被加载

这个过程实现得益得以依赖于pytorch 1.9的一个叫meta device的玩意儿,
PyTorch 1.9引入了一种新的设备,称为元设备(meta device)。
这使我们能够创建没有任何数据附加的张量,元设备上的张量只需要一个shape,只要你在元设备上,你就可以创建任意大的张量,而不必担心CPU(或GPU)的RAM够不够。
比如下面的代码,内存不够的话就会崩掉
import
 torch

large_tensor = torch.randn(
100000
100000
)

这个大张量需要4 * 10**10字节(默认精度是FP32,所以张量的每个元素占用4字节),因此需要40GB的RAM。然而,在元设备上执行相同的操作就可以正常运行:
import
 torch

large_tensor = torch.randn(
100000
100000
, device=
"meta"
)

这个张量没有关联的数据,只有一个形状。你可以直接在元设备上实例化一个模型:
large_model = torch.nn.Linear(
100000
100000
, device=
"meta"
)

但是对于现成的模型来说,这种语法需要你重写所有的建模代码,以便每个模型的子部分都接受并传递一个设备关键字参数。由于这对Transformers库的预训练模型来说不切实际,accelerate库有一个context manager,整合了meta device可以实例化一个空模型。
# Load meta model (no memory used)
with
 init_empty_weights():

    self.model = AutoModelForCausalLM.from_config(self.config, trust_remote_code=
True
)

    self.model.tie_weights()

这一步很关键,我们知道每个权重的形状,因此我们可以知道一旦我们完全加载预训练的张量,它们将消耗多少内存。因此,我们可以决定如何在CPU和GPU之间分割我们的模型。
除此之外,定义了两个关键的方法,分别是load_layer_to_cpu,负责把 权重从disk挪到CPU,另外一个是move_layer_to_device,负责把权重从cpu挪到显卡。还有一个释放显存的方法clean_memory,负责清空显存。
defload_layer_to_cpu(self, layer_name):
    self.weights_loader.set_state_dict(layer_name, self.device)

    state_dict = self.weights_loader.get_state_dict(self.device)

if"value_head.weight"in
 state_dict:

        state_dict = {
"lm_head.weight"
 : state_dict[
"value_head.weight"
]}

return
 state_dict


defmove_layer_to_device(self, state_dict):
for
 param_name, param 
in
 state_dict.items():

assert
 param.dtype != torch.int8, 
"int8 not supported (need to add fp16_statistics)"
        set_module_tensor_to_device(self.model, param_name, self.device, value=param, dtype=self.dtype)


defclean_memory():
    gc.collect()

    ctypes.CDLL(
"libc.so.6"
).malloc_trim(
0
)

    torch.cuda.empty_cache()

好了,对过程大概弄懂了就可以看完整的代码了。注意,下面的代码也包含了题目设定里特定的prefix和suffix,正常的推理可以忽略相关的逻辑,仅保留一个prompt即可。下面展示完整的代码
# For LLM
from
 transformers 
import
 AutoConfig, AutoModelForCausalLM, AutoTokenizer, AutoModel

from
 accelerate 
import
 init_empty_weights

from
 accelerate.utils.modeling 
import
 set_module_tensor_to_device

from
 safetensors.torch 
import
 load_file

from
 optimum.bettertransformer 
import
 BetterTransformer


N_BATCHES = 
3
MAX_LENGTH = 
4096

defclean_memory():
    gc.collect()

    ctypes.CDLL(
"libc.so.6"
).malloc_trim(
0
)

    torch.cuda.empty_cache()



# Class for sharded llama
classShardedLlama:
def__init__(self, checkpoint_path, weights_loader, device="cuda:0", dtype=torch.float16):

# Save parameters
        self.checkpoint_path = Path(checkpoint_path)

        self.weights_loader = weights_loader

        self.device = device 

        self.dtype = dtype


# Create model
        self.config = AutoConfig.from_pretrained(self.checkpoint_path)   

        self.tokenizer = AutoTokenizer.from_pretrained(checkpoint_path)

        self.tokenizer.pad_token = self.tokenizer.eos_token

        self.tokenizer.padding_side = 
"right"
        self.init_model()

        self.layer_names = [
"model.embed_tokens"
] + [
f"model.layers.{i}"for
 i 
in
 range(len(self.model.model.layers))] + [
"model.norm"
"value_head"
]


definit_model(self):

# Load meta model (no memory used)
with
 init_empty_weights():

            self.model = AutoModelForCausalLM.from_config(self.config)

            self.model.lm_head = torch.nn.Linear(
8192
8
, bias=
False
# originally 32k
            self.model.eval()

            self.model = BetterTransformer.transform(self.model) 
# enable flash attention
            self.model.tie_weights()


        self.layers = [self.model.model.embed_tokens] + list(self.model.model.layers) + [self.model.model.norm, self.model.lm_head]


# Move buffers to device (note that much GPU memory used)
for
 buffer_name, buffer 
in
 self.model.named_buffers():

            set_module_tensor_to_device(self.model, buffer_name, self.device, value=buffer, dtype=self.dtype)


defload_layer_to_cpu(self, layer_name):
        self.weights_loader.set_state_dict(layer_name, self.device)

        state_dict = self.weights_loader.get_state_dict(self.device)

if"value_head.weight"in
 state_dict:

            state_dict = {
"lm_head.weight"
 : state_dict[
"value_head.weight"
]}

return
 state_dict


defmove_layer_to_device(self, state_dict):
for
 param_name, param 
in
 state_dict.items():

assert
 param.dtype != torch.int8, 
"int8 not supported (need to add fp16_statistics)"
            set_module_tensor_to_device(self.model, param_name, self.device, value=param, dtype=self.dtype)


def__call__(self, inputs):
# inputs = [(prefix, suffix), ...] with prefix.shape[0] = 1 and suffix.shape[0] = 5

# Reboot the model to make sure buffers are loaded and memory is clean
del
 self.model

        clean_memory()

        self.init_model()


# Send batch to device
        batch = [(prefix.to(self.device), suffix.to(self.device)) 
for
 prefix, suffix 
in
 inputs]

        n_suffixes = len(batch[
0
][
1
])

        suffix_eos = [(suffix != self.tokenizer.pad_token_id).sum(
1
) - 
1for
 _, suffix 
in
 inputs]


# Create attention mask for the largest input, and position ids to use KV cache
        attention_mask = torch.ones(MAX_LENGTH, MAX_LENGTH)

        attention_mask = attention_mask.triu(diagonal=
1
)[
None
None
, ...] == 
0
        attention_mask = attention_mask.to(self.device)

        position_ids = torch.arange(MAX_LENGTH, dtype=torch.long, device=self.device)[
None
, :]


with
 ThreadPoolExecutor() 
as
 executor, torch.inference_mode():


# Load first layer
            future = executor.submit(self.load_layer_to_cpu, 
"model.embed_tokens"
)


for
 i, (layer_name, layer) 
in
 tqdm(enumerate(zip(self.layer_names, self.layers)), desc=self.device, total=len(self.layers)):


# Load current layer and prepare next layer
                state_dict = future.result()

if
 (i + 
1
) < len(self.layer_names):

                    future = executor.submit(self.load_layer_to_cpu, self.layer_names[i + 
1
])

                self.move_layer_to_device(state_dict)


# Run layer
for
 j, (prefix, suffix) 
in
 enumerate(batch):

if
 layer_name == 
"model.embed_tokens"
:

                        batch[j] = (layer(prefix), layer(suffix))

elif
 layer_name == 
"model.norm"
:

# Only keep the last token at this point
                        batch[j] = (
None
, layer(suffix[torch.arange(n_suffixes), suffix_eos[j]][:, 
None
]))

elif
 layer_name == 
"value_head"
:

                        batch[j] = layer(suffix)[:, 
0
].mean(
1
).detach().cpu().numpy()

else
:

# Run prefix
                        len_p, len_s = prefix.shape[
1
], suffix.shape[
1
]

                        new_prefix, (k_cache, v_cache) = layer(prefix, use_cache=
True
, attention_mask=attention_mask[:, :, -len_p:, -len_p:])


# Run suffix
                        pos = position_ids[:, len_p:len_p + len_s].expand(n_suffixes, 
-1
)

                        attn = attention_mask[:, :, -len_s:, -len_p - len_s:].expand(n_suffixes, 
-1
-1
-1
)

                        kv_cache = (k_cache.expand(n_suffixes, 
-1
-1
-1
), v_cache.expand(n_suffixes, 
-1
-1
-1
))

                        new_suffix = layer(suffix, past_key_value=kv_cache, position_ids=pos, attention_mask=attn)[
0
]

                        batch[j] = (new_prefix, new_suffix)


# Remove previous layer from memory (including buffers)
                layer.to(
"meta"
)

                clean_memory() 
# proposed by CPMP

# Get scores
return
 batch





defrun_model(device, df, weights_loader):
    model = ShardedLlama(checkpoint_path, weights_loader, device=device)

    f = partial(get_tokens, tokenizer=model.tokenizer)

    inputs = df.apply(f, axis=
1
).values

    batches = np.array_split(inputs, N_BATCHES)

    outputs = []

for
 i, batch 
in
 enumerate(batches):

        outputs += model(batch)

return
 outputs


完整的代码可以参考链接:https://www.kaggle.com/code/simjeg/platypus2-70b-without-wikipedia-rag
相关文章
其他精彩文章翻阅公众号历史文章
包包算法笔记是包大人在班车通勤上,进行知识,职业,经验分享的地方。最白的话讲专业的知识。

回复“刷题”获取高效刷题经验,回复“面试”获取算法校招面试宝典,回复“大模型”获取大模型技术资料。
最近高强度写大模型原创!

进讨论群的同学,加微信号logits,备注进群
继续阅读
阅读原文