2023-06-20 springboot 接入gpt stream

2023-06-19  本文已影响0人  江江江123

之前用 Go Gin 部署了 GPT stream，但目前项目整体框架用的是 Java，为了和业务结合，还是实现了 Spring Boot 版。

为什么会有stream?

gpt是生成式的,stream模式非常适合。
一个不得不用的理由：max_tokens 设置得较大时，普通（非流式）模式容易导致接口超时。

技术点

还是sse,springboot api返回 SseEmitter
okhttp 接收 stream 数据，需要引入 okhttp-sse 包

pom引入

        <!-- Spring MVC web starter with its default embedded Tomcat excluded;
             Undertow (next dependency) is used as the servlet container instead. -->
        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter-web</artifactId>
            <exclusions>
                <exclusion>
                    <groupId>org.springframework.boot</groupId>
                    <artifactId>spring-boot-starter-tomcat</artifactId>
                </exclusion>
            </exclusions>
        </dependency>
        <!-- Undertow servlet container replacing the excluded Tomcat. -->
        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter-undertow</artifactId>
        </dependency>
        <!-- OkHttp client used for the upstream HTTP call. Its transitive okio is excluded so that
             the explicitly versioned okio declared below is used instead.
             NOTE(review): okio.version is defined elsewhere in the pom; confirm it matches what
             this OkHttp version requires. Versions of okhttp / okhttp-sse are not given here,
             presumably managed by Spring Boot dependency management; verify. -->
        <dependency>
            <groupId>com.squareup.okhttp3</groupId>
            <artifactId>okhttp</artifactId>
            <exclusions>
                <exclusion>
                    <groupId>com.squareup.okio</groupId>
                    <artifactId>okio</artifactId>
                </exclusion>
            </exclusions>
        </dependency>
        <!-- Replacement okio for the one excluded from okhttp above. -->
        <dependency>
            <groupId>com.squareup.okio</groupId>
            <artifactId>okio</artifactId>
            <version>${okio.version}</version>
        </dependency>
        <!-- okhttp-sse provides the okhttp3.sse package (EventSource, EventSources,
             EventSourceListener) used to consume the streamed response. -->
        <dependency>
            <groupId>com.squareup.okhttp3</groupId>
            <artifactId>okhttp-sse</artifactId>
        </dependency>

entity

    /**
     * Request parameters of the GPT-3.5 stream endpoint (the controller's {@code streamOpenAI}
     * takes this type). Lombok {@code @Data} generates accessors, equals/hashCode and toString.
     *
     * NOTE(review): unlike the sibling DTOs this class is not {@code static}; if it is nested in
     * another class, Spring cannot instantiate it for request binding — confirm it is top-level.
     */
    @Data
    public class ChatGpt35Dto {
        // presumably the user's message text — NOTE(review): not read by any code shown here
        private String msg;
    }

    /**
     * JSON body of the upstream chat completion request; serialized with fastjson
     * ({@code JSON.toJSONString}) before it is sent.
     * Field names are snake_case, presumably so the serialized keys match the OpenAI API directly.
     */
    @Data
    public static class OpenAiRequest {
        // model id; the service sets "gpt-3.5-turbo"
        public String model;
        // conversation messages — left as a TODO by the service shown here
        public List<GptMessage> messages;
        // sampling temperature; the service sets 0.7
        public Double temperature;
        // upper bound on generated tokens; not set by the service shown here
        public Integer max_tokens;
        // set to true by the service so the answer is consumed as an SSE stream
        public Boolean stream;
    }

    /**
     * A chat message. Used as the element type of the request's {@code messages} list and as the
     * {@code delta} of each streamed response choice.
     */
    @Data
    public static class GptMessage {
        // presumably "system" / "user" / "assistant" — confirm against the API
        public String role;
        // message text; in a streamed delta, the next fragment of the answer
        public String content;
        // optional participant name — NOTE(review): not set or read by any code shown here
        public String name;
    }

    /**
     * One parsed chunk of the streamed response (the {@code data} payload of a single SSE event).
     * NOTE(review): "Steam" is presumably a typo for "Stream"; renaming would require touching callers.
     */
    @Data
    public static class OpenAISteamResponse {

        public String id;
        public String object;
        // presumably a Unix timestamp in seconds — NOTE(review): an int overflows in 2038; long would be safer
        public int created;
        public String model;
        // the stream handler only reads choices.get(0)
        public List<ChoicesBean> choices;

        /** One choice of a chunk; {@code delta.content} carries the newly generated text fragment. */
        @Data
        public static class ChoicesBean {
            public int index;
            public GptMessage delta;
            // not read by the code shown here
            public String finish_reason;
        }
    }

okhttp sse utils方法定义

    /**
     * Opens a server-sent-event stream by POSTing {@code json} to {@code url}.
     *
     * <p>The event source is created through OkHttp's {@code EventSources} factory with the shared
     * client from {@code OkHttpUtils}, and every stream callback goes to {@code listener}. The created
     * {@code EventSource} is not returned, so callers cannot cancel it from here.
     *
     * @param url      endpoint to POST to
     * @param json     request body, sent as {@code application/json; charset=utf-8}
     * @param headers  HTTP headers added to the request
     * @param listener receives the open/event/closed/failure callbacks of the stream
     */
    public static void sse(String url, String json, Map<String, String> headers, EventSourceListener listener) {
        okhttp3.MediaType jsonMediaType = okhttp3.MediaType.parse("application/json; charset=utf-8");
        okhttp3.RequestBody requestBody = okhttp3.RequestBody.create(jsonMediaType, json);

        Request.Builder builder = new Request.Builder();
        builder.url(url);
        builder.headers(Headers.of(headers));
        builder.post(requestBody);
        Request sseRequest = builder.build();

        EventSources.createFactory(OkHttpUtils.getOkHttpClient()).newEventSource(sseRequest, listener);
    }
    /**
     * Sends an {@code error} event carrying {@code errorMessage} to the client, then completes the emitter.
     *
     * <p>If the event cannot be sent, the emitter is completed with that failure instead and is
     * not completed a second time. (Previously {@code complete()} ran after
     * {@code completeWithError(e)} as well.)
     *
     * @param sseEmitter   emitter of the current SSE response
     * @param errorMessage text placed in the event's {@code data} field
     */
    public static void sendSseError(SseEmitter sseEmitter, String errorMessage) {
        log.error("sse error {}", errorMessage);
        try {
            sseEmitter.send(SseEmitter.event().name("error").data(errorMessage));
        } catch (Exception e) {
            // The event could not be delivered: finish with the failure and stop here.
            log.error("failed to send sse error event", e);
            sseEmitter.completeWithError(e);
            return;
        }
        sseEmitter.complete();
    }

controller

/**
 * REST endpoint exposing the GPT-3.5 chat as a server-sent-event stream.
 */
@RestController
public class ChatGptController {

    /** Service that creates the SSE emitter for a chat request. */
    @Autowired
    ChatGptService chatGptService;

    /**
     * Opens an SSE stream (text/event-stream) for the given chat request.
     *
     * @param dto request parameters bound from the GET request
     * @return the emitter produced by {@link ChatGptService#gpt35Stream}
     */
    @GetMapping(path = "/openai/gpt35/stream", produces = MediaType.TEXT_EVENT_STREAM_VALUE)
    public SseEmitter streamOpenAI(ChatGpt35Dto dto) {
        SseEmitter emitter = chatGptService.gpt35Stream(dto);
        return emitter;
    }
}

service

省略接口的定义,直接写一下实现

    /**
     * Starts a streaming GPT-3.5 chat for {@code dto} and returns the emitter the client reads from.
     *
     * <p>A null or blank question is rejected immediately with an {@code error} event. Otherwise the
     * request is built and passed to {@code okHttpEvent}, which relays the streamed answer to the
     * returned emitter.
     *
     * <p>NOTE(review): {@code new SseEmitter()} relies on the default async request timeout. Long
     * generations may end with an AsyncRequestTimeoutException, so confirm the configured value.
     *
     * @param dto chat request; {@code getQuestion()} must contain non-blank text
     * @return emitter for the client's SSE response
     */
    public SseEmitter chat(ChatVirtualRoleDto dto) {
        SseEmitter sseEmitter = new SseEmitter();
        if (dto.getQuestion() == null || StringUtils.isEmpty(dto.getQuestion().trim())) {
            sendSseError(sseEmitter, "question is error");
            return sseEmitter;
        }
        OpenAiRequest openAiRequest = new OpenAiRequest();
        openAiRequest.model = "gpt-3.5-turbo";
        openAiRequest.temperature = 0.7;
        openAiRequest.stream = true;
        // TODO: build openAiRequest.messages from the business context.
        log.info("open ai request: {}", JSON.toJSONString(openAiRequest));
        okHttpEvent(sseEmitter, openAiRequest, openAiStreamHeaders, answer -> {
            // TODO: handle the complete answer; the returned value is sent as the "stop" event data.
            // The callback is a Function<String, Long>, so the placeholder must be a long literal:
            // `return 1;` (an int) is not assignable to Long and does not compile.
            return 1L;
        });
        return sseEmitter;
    }

    /**
     * Streams a chat completion from the upstream API and relays each content delta to {@code emitter}.
     *
     * <p>Each chunk whose first choice has non-empty delta content is forwarded as a {@code message}
     * event (the serialized choice) and appended to the accumulated answer. When the upstream sends
     * {@code [DONE]}, {@code function} is applied to the full answer, its result is sent as a
     * {@code stop} event, and the emitter is completed. Upstream failures are reported to the client
     * via {@code sendSseError}.
     *
     * @param emitter             emitter of the client's SSE response
     * @param openAiRequest       request body, serialized to JSON
     * @param openAiStreamHeaders HTTP headers for the upstream call
     * @param function            receives the complete answer; its return value becomes the data
     *                            of the {@code stop} event
     */
    public static void okHttpEvent(SseEmitter emitter, OpenAiRequest openAiRequest, Map<String, String> openAiStreamHeaders, Function<String, Long> function) {
        StringBuilder answer = new StringBuilder();
        OkHttpUtils.sse(openAiUrl, JSON.toJSONString(openAiRequest), openAiStreamHeaders, new EventSourceListener() {
            @Override
            public void onOpen(EventSource eventSource, Response response) {
            }

            @Override
            public void onEvent(EventSource eventSource, String id, String type, String data) {
                // Normal termination: the upstream sends a literal "[DONE]" payload.
                if ("[DONE]".equals(data)) {
                    // Hand the complete answer to the caller's business logic.
                    Long historyId = function.apply(answer.toString());
                    try {
                        emitter.send(SseEmitter.event().name("stop").data(historyId));
                    } catch (IOException e) {
                        log.warn("failed to send sse stop event", e);
                    }
                    emitter.complete();
                    return;
                }
                OpenAISteamResponse openAiResponse = JSONObject.parseObject(data, OpenAISteamResponse.class);
                // Defensive: a chunk without choices would otherwise throw and abort the whole stream.
                if (openAiResponse == null || openAiResponse.choices == null || openAiResponse.choices.isEmpty()) {
                    return;
                }
                OpenAISteamResponse.ChoicesBean choicesBean = openAiResponse.choices.get(0);
                // Skip chunks without content, otherwise the front end receives many nulls.
                if (choicesBean.delta == null || StringUtils.isEmpty(choicesBean.delta.content)) {
                    return;
                }
                answer.append(choicesBean.delta.content);
                try {
                    emitter.send(SseEmitter.event().name("message").data(JSON.toJSONString(choicesBean)));
                } catch (IOException e) {
                    // The client can no longer be reached: stop reading the upstream stream
                    // instead of continuing to receive tokens nobody will get.
                    log.warn("failed to send sse message, cancelling upstream stream", e);
                    eventSource.cancel();
                    emitter.complete();
                }
            }

            @Override
            public void onClosed(EventSource eventSource) {
                emitter.complete();
            }

            @Override
            public void onFailure(EventSource eventSource, Throwable t, Response response) {
                // OkHttp passes a null response for network-level failures (connect/read errors);
                // when present, the status is most often 429.
                String errorMessage = response != null ? String.valueOf(response.code()) : "network error";
                sendSseError(emitter, errorMessage);
                log.error("event source failure", t);
            }
        });
    }

关于异常处理

大部分情况下，Spring Boot 会接入统一的异常处理并把结果返回给前端；但 SSE 接口发生异常时，如果统一异常处理返回的是标准对象而不是 SseEmitter，就会抛出 Spring MVC 的异常。
所以要专门捕获 SSE 中抛出的异常（如下例中的异步请求超时 AsyncRequestTimeoutException）。

/**
 * Global exception handling for asynchronous (SSE) requests.
 */
@Slf4j
@RestControllerAdvice
public class ControllerAdviceConf {

    /**
     * Handles {@link AsyncRequestTimeoutException} by logging it. The handler returns nothing, so no
     * standard response object is written back to an SSE response.
     *
     * @param timeoutException the timeout raised for the async request
     */
    @ExceptionHandler(AsyncRequestTimeoutException.class)
    public void myExceptionHandler(AsyncRequestTimeoutException timeoutException) {
        log.error("接口异常 async timeout");
        // Further handling (persisting to a database, alerting, ...) is intentionally omitted here.
    }
}
上一篇下一篇

猜你喜欢

热点阅读