2023-06-20 Java 调用 gpt-3.5-turbo-0613 的 function calling

2023-06-19  本文已影响0人  江江江123

技术点

  1. Java class 转 JSON Schema

简单跑一下官方案例:

在 pom.xml 中引入 jsonSchema 相关依赖

        <dependency>
            <groupId>com.fasterxml.jackson.module</groupId>
            <artifactId>jackson-module-jsonSchema</artifactId>
            <version>2.14.2</version>
        </dependency>

实体定义

    /**
     * Request body for the chat call issued from {@code main}, where an instance is
     * serialized with {@code JSON.toJSONString} and POSTed to {@code openAiUrl}.
     *
     * <p>Field names are snake_case; presumably they must match the JSON keys the
     * remote API expects, so do not rename them without checking the API reference.
     */
    @Data
    public static class OpenAiRequest {
        // Model identifier; main uses "gpt-3.5-turbo-0613".
        public String model;
        // Conversation so far; main appends the user question, the returned
        // function-call message and the function result to this list.
        public List<GptMessage> messages;
        // Not set by main.
        public Double temperature;
        // Not set by main.
        public Integer max_tokens;
        // Not set by main.
        public Boolean stream;
        // Functions offered to the model; main registers a single entry.
        public List<GptFunction> functions;
        // Not set by main (the setter call there is commented out).
        public String function_call;

    }

    /**
     * One entry of {@code OpenAiRequest.functions}: a function the request offers
     * to the model.
     */
    @Data
    public static class GptFunction {
        // Schema of the function's arguments; main fills this with the result of
        // getJSONSchema(GetWeatherInfoRequest.class).
        public Object parameters;
        // Human-readable description of what the function does.
        public String description;
        // Function name; main copies the name from the response's function_call
        // into the follow-up "function" message.
        public String name;
    }

    /**
     * A single chat message. Used both for messages sent in
     * {@code OpenAiRequest.messages} and for the message read back from
     * {@code OpenAiResponse.ChoicesBean}.
     */
    @Data
    public static class GptMessage {
        // Message role; main sends "user" and "function".
        public String role;
        // Message text. For the "function" message main puts the function's
        // JSON result here.
        public String content;
        // Set by main only on the "function" message, to the called function's name.
        public String name;
        // Non-null on a response message when a function call was requested;
        // main checks this to decide whether to run the local function.
        public FunctionCall function_call;
    }

    /**
     * The function invocation carried by a response message
     * ({@code GptMessage.function_call}).
     */
    @Data
    public static class FunctionCall {
        // Name of the function to invoke.
        public String name;
        // Arguments as a JSON string; main parses it into GetWeatherInfoRequest
        // with JSON.parseObject.
        public String arguments;
    }

    /**
     * Response body of the chat call; main obtains it via
     * {@code JSON.parseObject(result, OpenAiResponse.class)} and reads only
     * {@code choices.get(0).message}.
     */
    @Data
    public static class OpenAiResponse {
        public String id;
        public String object;
        // NOTE(review): if this is an epoch-seconds timestamp, int overflows in
        // 2038 — confirm the API's type before relying on it.
        public int created;
        public UsageBean usage;
        public List<ChoicesBean> choices;

        /** Token counters reported with the response. */
        @Data
        public static class UsageBean {
            public int prompt_tokens;
            public int completion_tokens;
            public int total_tokens;
        }

        /** One candidate answer; main uses only the first element of the list. */
        @Data
        public static class ChoicesBean {
            public int index;
            // The returned message; may carry a function_call instead of content.
            public GptMessage message;
            public String finish_reason;
        }
    }

简单定义查询天气的方法

    /**
     * Stub weather lookup used by the demo.
     *
     * <p>The {@code request} argument is not read: the same canned JSON document
     * is returned for every call.
     *
     * @param request the parsed function-call arguments (currently ignored)
     * @return a fixed JSON string describing the weather
     */
    static String getCurrentWeather(GetWeatherInfoRequest request) {
        final String cannedReport =
                "{ \"temperature\": 22, \"unit\": \"celsius\", \"description\": \"Sunny\" }";
        return cannedReport;
    }

    /**
     * Arguments of {@code getCurrentWeather}. main passes this class to
     * {@code getJSONSchema} to build the function's {@code parameters}, and
     * parses the returned {@code arguments} JSON string back into it.
     */
    @Data
    static class GetWeatherInfoRequest {
        // Marked required and given a description through the Jackson annotations,
        // which are presumably picked up by the schema generation — verify output.
        @JsonProperty(required = true)
        @JsonPropertyDescription("The city and state, e.g. San Francisco, CA")
        private String location;
        // Defaults to fahrenheit when the field is not assigned.
        private Unit unit = Unit.fahrenheit;
    }

    /**
     * Temperature unit accepted by {@code GetWeatherInfoRequest}. Constants are
     * lower-case; presumably so their names appear verbatim in the generated
     * schema and in the arguments JSON — do not rename without checking both.
     */
    enum Unit {
        celsius,
        fahrenheit;
    }

定义一个将 Java class 转换为 JSON Schema 的方法

     /**
      * Builds a JSON Schema for {@code type} with Jackson's jsonSchema module and
      * returns it as the object produced by {@code JSON.parseObject}, so it can be
      * embedded in a request that is serialized with the same JSON library.
      *
      * @param type the class to describe
      * @return the generated schema as a parsed JSON object
      * @throws IllegalStateException if Jackson cannot inspect {@code type} or
      *         cannot serialize the resulting schema; failing loudly avoids
      *         sending a function definition with its parameters silently missing
      */
     static Object getJSONSchema(Class<?> type) {
        ObjectMapper mapper = new ObjectMapper();
        SchemaFactoryWrapper visitor = new SchemaFactoryWrapper();
        try {
            mapper.acceptJsonFormatVisitor(type, visitor);
            JsonSchema schema = visitor.finalSchema();
            // The text is parsed straight back into an object, so pretty-printing
            // it first would only add work.
            return JSON.parseObject(mapper.writeValueAsString(schema));
        } catch (IOException e) {
            throw new IllegalStateException("Failed to generate JSON schema for " + type, e);
        }
    }

写一个简单的测试

    /**
     * Runs one function-calling round trip:
     * <ol>
     *   <li>send the user question together with the {@code getCurrentWeather}
     *       function definition;</li>
     *   <li>if the returned message carries a {@code function_call}, run
     *       {@code getCurrentWeather} with the parsed arguments;</li>
     *   <li>append the returned message and the function result to the
     *       conversation, send it again and print the raw second response.</li>
     * </ol>
     * Nothing is printed when the first response carries no function call.
     *
     * @param args unused
     * @throws BaseException with {@code SystemErrorType.GATEWAY_ERROR} when a POST
     *         reports an error or the first response contains no usable choice
     */
    public static void main(String[] args) {
        OpenAiRequest openAiRequest = new OpenAiRequest();
        openAiRequest.model = "gpt-3.5-turbo-0613";
        // function_call is left unset here; to set it explicitly:
        //   openAiRequest.setFunction_call("auto");

        List<GptMessage> messages = new ArrayList<>();
        GptMessage question = new GptMessage();
        question.setRole("user");
        question.setContent("What's the weather like in Boston?");
        messages.add(question);
        openAiRequest.messages = messages;

        List<GptFunction> functions = new ArrayList<>();
        GptFunction gptFunction = new GptFunction();
        gptFunction.setName("getCurrentWeather");
        gptFunction.setDescription("Get the current weather in a given location");
        gptFunction.setParameters(getJSONSchema(GetWeatherInfoRequest.class));
        functions.add(gptFunction);
        openAiRequest.setFunctions(functions);

        String result = postChatCompletion(openAiRequest);
        OpenAiResponse response = JSON.parseObject(result, OpenAiResponse.class);
        // A body without choices cannot be processed further; report it the same
        // way as a failed call instead of failing with an NPE / index error.
        if (response == null || response.getChoices() == null || response.getChoices().isEmpty()) {
            throw new BaseException(SystemErrorType.GATEWAY_ERROR);
        }
        GptMessage message = response.getChoices().get(0).getMessage();
        if (message == null) {
            throw new BaseException(SystemErrorType.GATEWAY_ERROR);
        }
        if (message.function_call == null) {
            return;
        }

        FunctionCall functionCall = message.function_call;
        String funcName = functionCall.getName();
        // With several functions, dispatch on funcName via reflection or a Map.
        String funcResponse = getCurrentWeather(JSON.parseObject(functionCall.getArguments(), GetWeatherInfoRequest.class));
        // The returned message's content is null; a null field is dropped when the
        // request is serialized to JSON, which leads to an HTTP 400. Put a
        // placeholder string in its place before echoing the message back.
        message.setContent("null");
        messages.add(message);

        GptMessage funcQuestion = new GptMessage();
        funcQuestion.setRole("function");
        funcQuestion.setContent(funcResponse);
        funcQuestion.setName(funcName);
        messages.add(funcQuestion);

        System.out.println(postChatCompletion(openAiRequest));
    }

    /**
     * Serializes {@code request} to JSON, POSTs it to {@code openAiUrl} with
     * {@code openAiHeaders} and returns the raw response body.
     *
     * @throws BaseException with {@code SystemErrorType.GATEWAY_ERROR} when the
     *         HTTP helper reports an error
     */
    private static String postChatCompletion(OpenAiRequest request) {
        OkHttpUtils.ResultBean resultBean = OkHttpUtils.postJson(openAiUrl, JSON.toJSONString(request), openAiHeaders);
        if (resultBean.isError()) {
            throw new BaseException(SystemErrorType.GATEWAY_ERROR);
        }
        return resultBean.getResult();
    }
上一篇 下一篇

猜你喜欢

热点阅读