Adds code for image generation

This commit is contained in:
2024-10-05 17:07:18 +01:00
parent cb954ea5db
commit bbe5433d83
6 changed files with 216 additions and 0 deletions

View File

@@ -11,6 +11,7 @@ use crate::prelude::{
GenerateContentResponse, GenerateContentResponseResult, GenerationConfig, TextEmbeddingRequest,
TextEmbeddingResponse,
};
use crate::types::{PredictImageRequest, PredictImageResponse};
use crate::{prelude::Part, token_provider::TokenProvider};
pub static AUTH_SCOPE: &[&str] = &["https://www.googleapis.com/auth/cloud-platform"];
@@ -252,4 +253,27 @@ impl<T: TokenProvider + Clone> GeminiClient<T> {
tracing::debug!("count_tokens response: {:?}", txt_json);
Ok(serde_json::from_str(&txt_json)?)
}
pub async fn predict_image(
&self,
request: &PredictImageRequest,
model: &str,
) -> Result<PredictImageResponse> {
let endpoint_url = format!(
"https://{}/v1/projects/{}/locations/{}/publishers/google/models/{}:predict",
self.api_endpoint, self.project_id, self.location_id, model,
);
let access_token = self.token_provider.get_token(AUTH_SCOPE).await?;
let resp = self
.client
.post(&endpoint_url)
.bearer_auth(access_token)
.json(&request)
.send()
.await?;
let txt_json = resp.text().await?;
Ok(serde_json::from_str(&txt_json)?)
}
}