What JPA + Hibernate data type should I use to support the vector extension in a PostgreSQL database?

huangapple go评论64阅读模式
英文:

What JPA + Hibernate data type should I use to support the vector extension in a PostgreSQL database?

问题

在PostgreSQL数据库中,我应该使用哪种JPA + Hibernate数据类型来支持vector扩展,以便能够通过JPA实体创建嵌入向量(embedding)?

英文:

What JPA + Hibernate data type should I use to support the vector extension in a PostgreSQL database, so that it allows me to create embeddings using a JPA Entity?

CREATE TABLE items (id bigserial PRIMARY KEY, embedding vector(3));

pgvector

答案1

得分: 3

你可以使用Vlad Mihalcea的Hibernate Types库(现名Hypersistence Utils)将vector类型映射为List<Double>,这样就可以通过JpaRepository保存和查询。

  1. 在pom.xml文件中添加依赖:

    <!-- Hypersistence Utils (formerly Vlad Mihalcea's hibernate-types) provides the JsonType
         used to map the pgvector column. The "-hibernate-55" artifact is built for
         Hibernate 5.5/5.6 (javax.persistence); choose the artifact matching your Hibernate version. -->
    <dependency>
      <groupId>io.hypersistence</groupId>
      <artifactId>hypersistence-utils-hibernate-55</artifactId>
      <version>3.5.0</version>
    </dependency>
    
  2. 创建Item类:

    import com.fasterxml.jackson.annotation.JsonInclude;
    import io.hypersistence.utils.hibernate.type.json.JsonType;
    import lombok.Getter;
    import lombok.NoArgsConstructor;
    import lombok.Setter;
    import lombok.ToString;
    import org.hibernate.annotations.Type;
    import org.hibernate.annotations.TypeDef;

    import javax.persistence.*;
    import java.util.List;

    /**
     * JPA entity for a row of the {@code items} table, whose {@code embedding} column is a
     * pgvector {@code vector}.
     *
     * <p>The embedding is mapped with Hypersistence Utils' {@link JsonType}: a
     * {@code List<Double>} is serialized as the JSON array text {@code [x,y,z]}, which is
     * the same text format pgvector accepts for a {@code vector} value.
     *
     * <p>Lombok's {@code @Data} is deliberately not used. Its generated equals/hashCode would
     * cover the database-generated {@code id}, so the hash would change when the entity is
     * persisted. It would also cover the whole embedding list. Equality here is based on the
     * identifier only, and the hash code is constant.
     */
    @Getter
    @Setter
    @ToString
    @NoArgsConstructor
    @Entity
    @Table(name = "items")
    @JsonInclude(JsonInclude.Include.NON_NULL)
    @TypeDef(name = "json", typeClass = JsonType.class)
    public class Item {
        @Id
        @GeneratedValue(strategy = GenerationType.IDENTITY)
        private Long id;

        // "vector" without a dimension accepts any length; the DDL in the question uses vector(3).
        // NOTE: pgvector stores single-precision floats, so values read back have float precision.
        @Type(type = "json")
        @Column(columnDefinition = "vector")
        private List<Double> embedding;

        /**
         * Two items are equal when they are the same instance, or when both have been
         * persisted with the same identifier. A transient item (null id) equals only itself.
         */
        @Override
        public boolean equals(Object other) {
            if (this == other) {
                return true;
            }
            if (!(other instanceof Item)) {
                return false;
            }
            // getId() rather than field access so Hibernate proxies compare correctly.
            Item that = (Item) other;
            return id != null && id.equals(that.getId());
        }

        /**
         * Returns a per-class constant, so the hash does not change when {@code id} is
         * assigned on persist. Entities stay findable in hash-based collections.
         */
        @Override
        public int hashCode() {
            return Item.class.hashCode();
        }
    }
    
  3. 创建支持保存和查找的JpaRepository接口。您可以使用原生SQL编写自定义的findNearestNeighbors方法。

    import org.springframework.data.jpa.repository.JpaRepository;
    
    public interface ItemRepository extends JpaRepository<Item, Long> {
    
        // 通过向量查找最近的邻居,例如 value = "[1,2,3]"
        // 这也可以工作,强制转换等同于postgresql中的::运算符
        //@Query(nativeQuery = true, value = "SELECT * FROM items ORDER BY embedding <-> cast(? as vector) LIMIT 5")
        @Query(nativeQuery = true, value = "SELECT * FROM items ORDER BY embedding <-> ?\\:\\:vector LIMIT 5")
        List<Item> findNearestNeighbors(String value);
    
        // 通过同一表中的记录查找最近的邻居
        @Query(nativeQuery = true, value = "SELECT * FROM items WHERE id != :id ORDER BY embedding <-> (SELECT embedding FROM items WHERE id = :id) LIMIT 5")
        List<Item> findNearestNeighbors(Long id);
    }
    
  4. 测试创建、查询和findNearestNeighbors:

    // Repository under test, injected by Spring. Assumes the enclosing test class (not shown)
    // runs with a Spring application context, e.g. @SpringBootTest — confirm.
    @Autowired
    private ItemRepository itemRepository;
    
    /**
     * Saves one {@link Item} whose embedding holds 3 random values. With
     * {@code @Rollback(false)} the test transaction is committed, so the row stays in the table.
     */
    @Test
    @Rollback(false)
    @Transactional
    public void createItem() {
        Item newItem = new Item();
        Random random = new Random();
        // 3 elements, matching the vector(3) column from the question's DDL.
        int dimensions = 3;
        List<Double> randomEmbedding = new ArrayList<>();
        while (randomEmbedding.size() < dimensions) {
            randomEmbedding.add(random.nextDouble());
        }
        newItem.setEmbedding(randomEmbedding);
        itemRepository.save(newItem);
    }
    
    /** Prints every row of the items table to stdout; makes no assertions. */
    @Test
    public void loadItems() {
        System.out.println(itemRepository.findAll());
    }
    
    /** Prints the items nearest to a fixed query vector; makes no assertions. */
    @Test
    public void findNearestNeighbors() {
        // pgvector text literal; a String argument selects the String overload of the repository method.
        List<Item> neighbours = itemRepository.findNearestNeighbors("[0.1, 0.2, 0.3]");
        System.out.println(neighbours);
    }
    
英文:

You can use Vlad Mihalcea's Hibernate Types library (now Hypersistence Utils) to map a vector column to a List<Double>, so that you can save and query it with a JpaRepository.

  1. Add a dependency to the pom.xml file:

    <!-- Hypersistence Utils (formerly Vlad Mihalcea's hibernate-types) provides the JsonType
         used to map the pgvector column. The "-hibernate-55" artifact is built for
         Hibernate 5.5/5.6 (javax.persistence); choose the artifact matching your Hibernate version. -->
    <dependency>
      <groupId>io.hypersistence</groupId>
      <artifactId>hypersistence-utils-hibernate-55</artifactId>
      <version>3.5.0</version>
    </dependency>
    
  2. Create the Item class:

    import com.fasterxml.jackson.annotation.JsonInclude;
    import io.hypersistence.utils.hibernate.type.json.JsonType;
    import lombok.Data;
    import lombok.NoArgsConstructor;
    import org.hibernate.annotations.Type;
    import org.hibernate.annotations.TypeDef;
    
    import javax.persistence.*;
    import java.util.List;
    
    @Data
    @NoArgsConstructor
    @Entity
    @Table(name = &quot;items&quot;)
    @JsonInclude(JsonInclude.Include.NON_NULL)
    @TypeDef(name = &quot;json&quot;, typeClass = JsonType.class)
    public class Item {
        @Id
        @GeneratedValue(strategy = GenerationType.IDENTITY)
        private Long id;
    
        @Type(type = &quot;json&quot;)
        @Column(columnDefinition = &quot;vector&quot;)
        private List&lt;Double&gt; embedding;
    }
    
  3. Create a JpaRepository interface that supports save and find. You can write custom findNearestNeighbors methods with native SQL:

    import org.springframework.data.jpa.repository.JpaRepository;
    
    public interface ItemRepository extends JpaRepository&lt;Item, Long&gt; {
    
        // Find nearest neighbors by a vector, for example value = &quot;[1,2,3]&quot;
        // This also works, cast is equals to the :: operator in postgresql
        //@Query(nativeQuery = true, value = &quot;SELECT * FROM items ORDER BY embedding &lt;-&gt; cast(? as vector) LIMIT 5&quot;)
        @Query(nativeQuery = true, value = &quot;SELECT * FROM items ORDER BY embedding &lt;-&gt; ? \\:\\:vector LIMIT 5&quot;)
        List&lt;Item&gt; findNearestNeighbors(String value);
    
        // Find nearest neighbors by a record in the same table
        @Query(nativeQuery = true, value = &quot;SELECT * FROM items WHERE id != :id ORDER BY embedding &lt;-&gt; (SELECT embedding FROM items WHERE id = :id) LIMIT 5&quot;)
        List&lt;Item&gt; findNearestNeighbors(Long id);
    }
    
  4. Test create, query and findNearestNeighbors:

    // Repository under test, injected by Spring. Assumes the enclosing test class (not shown)
    // runs with a Spring application context, e.g. @SpringBootTest — confirm.
    @Autowired
    private ItemRepository itemRepository;
    
    @Test
    @Rollback(false)
    @Transactional
    public void createItem() {
        Item item = new Item();
        Random rand = new Random();
        List&lt;Double&gt; embedding = new ArrayList&lt;&gt;();
        for (int i = 0; i &lt; 3; i++)
            embedding.add(rand.nextDouble());
        item.setEmbedding(embedding);
        itemRepository.save(item);
    }
    
    @Test
    public void loadItems() {
        final List&lt;Item&gt; items = itemRepository.findAll();
        System.out.println(items);
    }
    
    @Test
    public void findNearestNeighbors() {
        final String value = &quot;[0.1, 0.2, 0.3]&quot;;
        final List&lt;Item&gt; items = itemRepository.findNearestNeighbors(value);
        System.out.println(items);
    }
    

huangapple
  • 本文由 发表于 2023年6月26日 13:30:51
  • 转载请务必保留本文链接:https://go.coder-hub.com/76553746.html
匿名

发表评论

匿名网友

:?: :razz: :sad: :evil: :!: :smile: :oops: :grin: :eek: :shock: :???: :cool: :lol: :mad: :twisted: :roll: :wink: :idea: :arrow: :neutral: :cry: :mrgreen:

确定