---
license: apache-2.0
language:
- zh
---
|
|
|
# Model Card for a Chinese Legal Text Correction Model
|
|
|
|
|
## Model Details
|
|
|
### Model Description
|
|
|
A proofreading (text-correction) model built specifically for the Chinese legal vertical domain.
|
|
|
Training data:

- **Dataset by:** [correct_law](https://huggingface.co./datasets/lzy510016411/correct_law/)
|
|
|
|
|
### Model Sources
|
|
|
The model was trained with LoRA on top of Qwen1.5-14B as the base, using the LLaMA-Factory framework.
|
|
|
Training parameters:
|
|
|
```yaml
quantization_bit: 4

stage: sft
do_train: true
finetuning_type: lora
lora_target: q_proj,gate_proj,v_proj,up_proj,k_proj,o_proj,down_proj
lora_rank: 32
lora_alpha: 64
lora_dropout: 0.05

ddp_timeout: 180000000
deepspeed: examples/deepspeed/ds_z2_config.json

dataset: <your datasets>  # set this yourself; we also mixed in a small amount of general QA (alpaca-style)
template: qwen
cutoff_len: 512
max_length: 512
overwrite_cache: true
preprocessing_num_workers: 16

output_dir: saves/qwen-14b/lora/sft
mix_strategy: interleave
logging_steps: 5
save_steps: 500
plot_loss: true
save_total_limit: 20
overwrite_output_dir: true

flash_attn: fa2
per_device_train_batch_size: 2
gradient_accumulation_steps: 8
learning_rate: 0.0001
num_train_epochs: 3
weight_decay: 0.01
optim: adamw_torch  # the 8-bit optimizer seems to have issues
lr_scheduler_type: cosine
warmup_steps: 0.01
bf16: true

load_best_model_at_end: true
val_size: 0.001
per_device_eval_batch_size: 1
evaluation_strategy: steps
eval_steps: 250
```
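As a quick sanity check on these settings, the effective global batch size is the per-device batch size times the gradient-accumulation steps times the GPU count. The sketch below assumes a 2-GPU run; the card does not state the actual training GPU count.

```python
# Effective global batch size implied by the training config above.
# NOTE: num_gpus = 2 is an assumption, not stated in the model card.
per_device_train_batch_size = 2
gradient_accumulation_steps = 8
num_gpus = 2

effective_batch_size = (
    per_device_train_batch_size * gradient_accumulation_steps * num_gpus
)
print(effective_batch_size)  # 32 samples per optimizer step
```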
|
|
|
## Uses
|
|
|
We recommend serving the model through vLLM's OpenAI-compatible server; launch script:
|
|
|
```sh
# <model path>: path to the Qwen1.5-14B base model.
# correct=...: path to the LoRA checkpoint directory (e.g. checkpoint-20000).
nohup /root/miniconda3/envs/py310/bin/python -m vllm.entrypoints.openai.api_server \
--model <model path> \
--port 7777 \
--tensor-parallel-size 2 \
--gpu-memory-utilization 0.80 \
--swap-space 8 \
--max-model-len 512 \
--max-log-len 512 \
--enable-lora \
--max-lora-rank 32 \
--max-cpu-loras 8 \
--max-num-seqs 8 \
--lora-modules correct=<lora-checkpoint-path> >/mnt/Models/base_llm.log 2>&1 &
```
|
|
|
Client usage:
|
|
|
```python
from openai import OpenAI

# Point the client at vLLM's OpenAI-compatible API server.
# Make sure the host and port match where vLLM is actually serving
# (the launch script above uses --port 7777).
openai_api_key = "EMPTY"
openai_api_base = "http://192.168.110.171:6660/v1"

client = OpenAI(
    api_key=openai_api_key,
    base_url=openai_api_base,
)

source_data = """中国公民出境如境,应当向出入境边防检查机关交验本人的护照或者其他旅行证件等出境入境证件,履行规定的手续,经查验准许,方可出境入境。
具备条件的口岸,出入境边防检查机关应当为中国公民出境入境提供专用通道等便利措施。"""

# A failed request raises an exception from the client rather than
# returning a falsy response, so no status-code check is needed.
chat_response = client.chat.completions.create(
    model="correct-lora",
    messages=[
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "对下列文本进行纠错:\n\n%s" % source_data},
    ],
    temperature=0.1,
)

content = chat_response.choices[0].message.content
# The model prepends a fixed 9-character label to corrected text; strip it
# before comparing against the input.
new_content = content[9:]
if new_content == source_data or content == '该文本没有错误':
    print('该文本没有错误')  # "the text contains no errors"
else:
    print(content)
```
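The post-processing step above (slice off the first nine characters, then compare with the input) can be captured in a small standalone helper. The 9-character slice assumes the model always prepends a fixed-length label to corrected text; `postprocess` and `LABEL_LEN` below are illustrative names, not part of the model's API.

```python
NO_ERROR = '该文本没有错误'  # the model's "no errors found" reply
LABEL_LEN = 9  # assumed length of the fixed label the model prepends

def postprocess(content: str, source: str) -> str:
    """Return NO_ERROR when nothing was corrected, else the model's full reply."""
    corrected = content[LABEL_LEN:]  # drop the leading label
    if corrected == source or content == NO_ERROR:
        return NO_ERROR
    return content

# A reply whose body matches the input means no correction was made:
print(postprocess('ABCDEFGHI原文与输入相同', '原文与输入相同'))  # 该文本没有错误
```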
|
|
|
|
|
### Direct Use
|
|
|
You can also load the model directly with transformers; we won't go into detail here.
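For reference, a minimal sketch of direct loading, assuming the standard transformers + PEFT workflow; the function name and both paths are placeholders, not values from this card.

```python
def load_correct_model(base_path: str, lora_path: str):
    """Load the Qwen1.5-14B base model and attach the LoRA adapter."""
    # Imports are local so this sketch can be read (and the function defined)
    # without transformers/peft installed.
    from transformers import AutoModelForCausalLM, AutoTokenizer
    from peft import PeftModel

    tokenizer = AutoTokenizer.from_pretrained(base_path)
    model = AutoModelForCausalLM.from_pretrained(base_path, device_map="auto")
    model = PeftModel.from_pretrained(model, lora_path)  # attach LoRA weights
    model.eval()
    return tokenizer, model
```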